Skip to content

WebSocket Server

The datadivr server provides a FastAPI-based WebSocket server that handles client connections, message routing, and event handling.

Basic Usage

import uvicorn
from datadivr import app, HandlerType, websocket_handler, WebSocketMessage
from datadivr.transport.messages import create_error_message

# Define handlers
@websocket_handler("sum_event", HandlerType.SERVER)
async def sum_handler(message: WebSocketMessage) -> WebSocketMessage:
    """Add up the ``numbers`` list found in the payload and reply to the sender.

    Returns a ``sum_handler_result`` message addressed to ``message.from_id``
    with the total as its payload. When ``numbers`` is not a list, or anything
    raises while reading or summing it, the reply is built by
    ``create_error_message`` with ``message.from_id`` instead.
    """
    try:
        values = message.payload.get("numbers")
        if isinstance(values, list):
            # Every entry is coerced with float(); a non-numeric entry raises
            # and is reported through the except branch below.
            total = sum(float(value) for value in values)
            return WebSocketMessage(
                event_name="sum_handler_result",
                payload=total,
                to=message.from_id,
            )
        return create_error_message(
            "Payload must contain a list of numbers",
            message.from_id
        )
    except Exception as exc:
        return create_error_message(f"Error: {exc}", message.from_id)

# Start the server: when this file is run as a script, serve `app` with
# uvicorn on 127.0.0.1:8765.
if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8765)

Built-in Handlers

The server comes with built-in handlers for common operations:

Sum Handler

@websocket_handler("sum_event", HandlerType.SERVER)
async def sum_handler(message: WebSocketMessage) -> WebSocketMessage:
    """Calculate sum of numbers in the payload.

    Reads ``numbers`` from ``message.payload`` and replies to
    ``message.from_id`` with a ``sum_handler_result`` message whose payload is
    the total. If ``numbers`` is not a list, or summing fails, the reply is an
    error message built by ``create_error_message`` instead.
    """
    try:
        numbers = message.payload.get("numbers")
        # Reject anything that is not a list up front: a missing key would
        # otherwise surface as an opaque "'NoneType' object is not iterable",
        # and a string such as "123" would be summed character by character.
        if not isinstance(numbers, list):
            return create_error_message(
                "Payload must contain a list of numbers",
                message.from_id,
            )

        result = sum(float(n) for n in numbers)
        return WebSocketMessage(
            event_name="sum_handler_result",
            payload=result,
            to=message.from_id,
        )
    except Exception as e:
        # Covers a payload without .get() and entries float() cannot convert.
        return create_error_message(f"Error: {e}", message.from_id)

Server Implementation

The server uses FastAPI and maintains a registry of connected clients:

# Module-level state: maps each accepted WebSocket to the client ID that was
# generated for it when the connection was opened.
clients: dict[WebSocket, str] = {}

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Handle incoming WebSocket connections.

    Registered on the ``/ws`` route; passes the socket to
    ``handle_connection`` and returns once that coroutine finishes.
    """
    await handle_connection(websocket)

async def handle_connection(websocket: WebSocket) -> None:
    """Manage client connection lifecycle.

    Accepts the socket, registers it in ``clients`` under a new UUID4 string,
    then loops: receive JSON, validate it into a ``WebSocketMessage``, stamp
    ``from_id`` with the client ID, pass it to ``handle_msg`` and broadcast any
    non-``None`` response.

    The loop ends silently on ``WebSocketDisconnect``; any other exception
    propagates to the caller. In both cases the socket is unregistered.
    """
    await websocket.accept()
    client_id = str(uuid.uuid4())
    clients[websocket] = client_id

    try:
        while True:
            data = await websocket.receive_json()
            message = WebSocketMessage.model_validate(data)
            # Overwrite whatever the client sent so it cannot spoof its identity.
            message.from_id = client_id
            response = await handle_msg(message)
            if response is not None:
                await broadcast(response, websocket)
    except WebSocketDisconnect:
        # A disconnect is the normal way for the receive loop to end.
        pass
    finally:
        # Unregister on every exit path, not only on a clean disconnect:
        # a validation or handler error would otherwise leave a stale entry
        # that later broadcasts keep trying to send to.
        clients.pop(websocket, None)

Message Broadcasting

The server supports three broadcasting modes:

  1. All Clients:
# to="all": sent to every registered client, the sender included
message = WebSocketMessage(
    event_name="announcement",
    message="Server maintenance in 5 minutes",
    to="all"
)
  2. Other Clients:
# to="others": sent to every registered client except the sender's own socket
message = WebSocketMessage(
    event_name="user_joined",
    message="New user connected",
    to="others"
)
  3. Specific Client:
# Any other value of `to` is treated as a client ID: only the client
# registered under that ID receives the message (nobody, if the ID is unknown)
message = WebSocketMessage(
    event_name="private_message",
    message="Your request was processed",
    to="client_123"
)

Error Handling

The server handles various error conditions:

  • Invalid message formats
  • Client disconnections
  • Message broadcasting failures

All errors are logged using structured logging via structlog:

try:
    message = WebSocketMessage.model_validate(data)
except ValueError as e:
    # Log the validation failure (with the active exception) and the client
    # that sent the bad payload.
    logger.exception("invalid_message_format", error=str(e), client_id=client_id)
    # `from None` matches the server implementation: the validation error has
    # just been logged, so it is not chained onto the domain exception.
    raise InvalidMessageFormat() from None

Reference

datadivr.transport.server

WebSocket server implementation for datadivr.

This module provides a FastAPI-based WebSocket server that handles client connections, message routing, and event handling.

Example
import uvicorn
from datadivr import app

# Serve the datadivr app with uvicorn on 127.0.0.1:8765
uvicorn.run(app, host="127.0.0.1", port=8765)

Classes

Functions

add_client(websocket)

Add a new client and return its client ID.

Source code in datadivr/transport/server.py
def add_client(websocket: WebSocket) -> str:
    """Register ``websocket`` under a newly generated UUID4 string and return that ID.

    The registry entry stores the socket together with an initially empty
    ``state`` dict. The new connection count is logged as ``client_connected``.
    """
    new_id = str(uuid.uuid4())
    entry = {"websocket": websocket, "state": {}}
    clients[new_id] = entry
    logger.info("client_connected", client_id=new_id, connected_clients=len(clients))
    return new_id

broadcast(message, sender) async

Broadcast a message to appropriate clients.

Source code in datadivr/transport/server.py
async def broadcast(message: WebSocketMessage, sender: WebSocket) -> None:
    """Broadcast a message to appropriate clients.

    Routing is decided by ``message.to``:

    * ``"all"`` - every registered client, the sender included.
    * ``"others"`` - every registered client whose socket is not ``sender``.
    * anything else - treated as a client ID; the message goes to that client
      only, or to nobody when the ID is not registered.

    A failed send is logged as ``broadcast_error`` and does not stop delivery
    to the remaining targets.

    Args:
        message: The message to deliver; serialized once with ``model_dump()``.
        sender: The socket the message originated from.
    """
    message_data = message.model_dump()

    # Resolve (client_id, websocket) pairs once, up front. Looking the ID up
    # again per target was O(n) each time and raised StopIteration - from
    # inside the error handler, aborting the whole broadcast - if the client
    # had been unregistered while an earlier send was awaiting.
    targets: list[tuple[str, WebSocket]] = []
    if message.to == "all":
        targets = [(cid, data["websocket"]) for cid, data in clients.items()]
    elif message.to == "others":
        targets = [(cid, data["websocket"]) for cid, data in clients.items() if data["websocket"] != sender]
    else:
        target_data = clients.get(message.to)
        if target_data:
            targets = [(message.to, target_data["websocket"])]

    logger.debug("broadcasting_message", message=message_data, num_targets=len(targets))

    for client_id, websocket in targets:
        try:
            await websocket.send_json(message_data)
            logger.debug("message_sent", client_id=client_id)
        except Exception as e:
            # Best effort: one broken socket must not block the other targets.
            logger.exception("broadcast_error", error=str(e), client_id=client_id)

close_client_connection(client_id) async

Close a client connection.

Source code in datadivr/transport/server.py
async def close_client_connection(client_id: str) -> None:
    """Drop ``client_id`` from the client registry; unknown IDs are ignored.

    Only the registry entry is removed here; this function does not call
    ``close()`` on the client's socket.
    """
    clients.pop(client_id, None)

get_client_state(client_id)

Retrieve the state information for a client by client ID.

Source code in datadivr/transport/server.py
def get_client_state(client_id: str) -> dict[str, Any] | None:
    """Return the ``state`` dict stored for ``client_id``.

    Returns ``None`` when the ID is not registered or its entry has no
    ``state`` key.
    """
    if client_id not in clients:
        return None
    entry = clients[client_id]
    return entry.get("state")

handle_connection(websocket) async

Handle a WebSocket connection lifecycle.

Source code in datadivr/transport/server.py
@BackgroundTasks.task()
async def handle_connection(websocket: WebSocket) -> None:
    """Handle a WebSocket connection lifecycle.

    Accepts the socket, registers it via ``add_client``, then loops: receive
    JSON, validate it into a ``WebSocketMessage``, stamp ``from_id`` with the
    client ID, dispatch through ``handle_msg`` and broadcast any non-``None``
    response.

    The client is unregistered on every exit path.

    Raises:
        InvalidMessageFormat: A ``ValueError`` was raised while validating or
            processing a message.
        Exception: Any other error is logged as ``websocket_error`` and
            re-raised. ``WebSocketDisconnect`` is not re-raised.
    """
    await websocket.accept()
    client_id = add_client(websocket)

    try:
        while True:
            data = await websocket.receive_json()
            try:
                message = WebSocketMessage.model_validate(data)
                # Overwrite whatever the client sent so it cannot spoof its identity.
                message.from_id = client_id
                response = await handle_msg(message)
                if response is not None:
                    await broadcast(response, websocket)
            except ValueError as e:
                logger.exception("invalid_message_format", error=str(e), client_id=client_id)
                raise InvalidMessageFormat() from None
    except WebSocketDisconnect:
        # A disconnect is the normal way for the receive loop to end.
        pass
    except Exception as e:
        logger.exception("websocket_error", error=str(e), client_id=client_id)
        raise
    finally:
        # Previously only a clean disconnect unregistered the client, so an
        # invalid message or handler failure left a stale registry entry that
        # later broadcasts kept trying to send to. remove_client ignores IDs
        # that are already gone, so calling it unconditionally is safe.
        remove_client(client_id)

handle_msg(message) async

Handle an incoming WebSocket message.

Source code in datadivr/transport/server.py
async def handle_msg(message: WebSocketMessage) -> WebSocketMessage | None:
    """Dispatch ``message`` to the server handler registered for its event name.

    Returns whatever the matching handler returns. When no server handler is
    registered for ``message.event_name``, the message itself is returned
    unchanged.
    """
    logger.debug("message_received", message=message.model_dump())

    server_handlers = get_handlers(HandlerType.SERVER)
    if message.event_name not in server_handlers:
        # No handler: hand the message back untouched.
        return message

    logger.info("handling_event", event_name=message.event_name)
    handler = server_handlers[message.event_name]
    return await handler(message)

lifespan(app) async

Handle startup and shutdown events.

Source code in datadivr/transport/server.py
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run startup work before the app serves and cleanup after it stops.

    Startup logs the registered server handler names and starts all background
    tasks. Shutdown calls ``close_client_connection`` for every registered
    client (logging, not raising, individual failures), stops the background
    tasks and empties the client registry.
    """
    logger.debug("startup_initiated")

    handler_names = list(get_handlers(HandlerType.SERVER).keys())
    logger.info("registered_server_handlers", handlers=handler_names)

    await BackgroundTasks.start_all()
    try:
        yield
    finally:
        logger.debug("shutdown_initiated", num_clients=len(clients))

        # Walk a snapshot of the IDs: closing a client deletes its registry
        # entry, which would otherwise mutate the dict during iteration.
        pending_ids = tuple(clients)
        for cid in pending_ids:
            try:
                await close_client_connection(cid)
                logger.debug("closed_client_connection", client_id=cid)
            except Exception as exc:
                logger.exception("client_close_error", error=str(exc), client_id=cid)

        await BackgroundTasks.stop_all()
        clients.clear()
        logger.debug("shutdown_completed")

remove_client(client_id)

Remove a client by its ID.

Source code in datadivr/transport/server.py
def remove_client(client_id: str) -> None:
    """Delete ``client_id`` from the registry and log ``client_disconnected``.

    Does nothing, and logs nothing, when the ID is not registered.
    """
    if client_id not in clients:
        return
    del clients[client_id]
    logger.info("client_disconnected", client_id=client_id)

update_client_state(client_id, **kwargs)

Update the state information for a client.

Source code in datadivr/transport/server.py
def update_client_state(client_id: str, **kwargs: Any) -> None:
    """Merge the given keyword arguments into the stored state of ``client_id``.

    Existing keys are overwritten; an unregistered ID is silently ignored.
    """
    if client_id not in clients:
        return
    state = clients[client_id]["state"]
    state.update(kwargs)

websocket_endpoint(websocket) async

Handle incoming WebSocket connections.

Source code in datadivr/transport/server.py
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve a WebSocket connection arriving on ``/ws``.

    Wraps ``handle_connection`` with ``BackgroundTasks.task`` under a name
    derived from ``id(websocket)``, then awaits the wrapped call with this
    socket.
    """
    task_name = f"ws_connection_{id(websocket)}"
    as_named_task = BackgroundTasks.task(name=task_name)
    tracked_handler = as_named_task(handle_connection)
    await tracked_handler(websocket)

options: show_root_heading: true show_source: true