"""WebSocket connection manager — multi-client broadcast for Atlus."""

import asyncio
import json
import logging
from typing import Any

from fastapi import WebSocket

log = logging.getLogger("atlus.ws")


class ConnectionManager:
    """Track active WebSocket connections per channel and broadcast messages.

    Sockets are grouped by a channel name (``"stats"`` by default). Sockets
    whose send fails during a broadcast are dropped from their channel, and
    channels left with no sockets are removed from the registry.
    """

    def __init__(self) -> None:
        # channel name -> set of accepted sockets subscribed to that channel
        self._connections: dict[str, set[WebSocket]] = {}

    @staticmethod
    def _encode(data: Any) -> str:
        """Return *data* as-is when it is already a ``str``, else its JSON text."""
        return data if isinstance(data, str) else json.dumps(data)

    async def connect(self, websocket: WebSocket, channel: str = "stats") -> None:
        """Accept *websocket* and subscribe it to *channel*."""
        await websocket.accept()
        # No await between lookup and add, so the set we add to is the live one.
        sockets = self._connections.setdefault(channel, set())
        sockets.add(websocket)
        log.info("WS connect: channel=%s total=%d", channel, len(sockets))

    def disconnect(self, websocket: WebSocket, channel: str = "stats") -> None:
        """Unsubscribe *websocket* from *channel*; a no-op if it is not registered."""
        sockets = self._connections.get(channel)
        if not sockets:
            return
        sockets.discard(websocket)
        log.info("WS disconnect: channel=%s total=%d", channel, len(sockets))
        if not sockets:
            # Drop empty channels so the registry does not grow without bound.
            del self._connections[channel]

    async def broadcast(self, data: Any, channel: str = "stats") -> None:
        """Send *data* (JSON-encoded unless already a str) to every socket on *channel*.

        Sends are best-effort: a socket whose send raises is logged at debug
        level and removed from the channel instead of aborting the broadcast.
        """
        sockets = self._connections.get(channel)
        if not sockets:
            return
        payload = self._encode(data)
        stale: list[WebSocket] = []
        # Iterate over a snapshot: each send awaits, and a concurrent
        # connect()/disconnect() would otherwise mutate the set mid-iteration
        # and raise "RuntimeError: Set changed size during iteration".
        for ws in list(sockets):
            try:
                await ws.send_text(payload)
            except Exception:
                log.debug(
                    "WS send failed: channel=%s; dropping socket",
                    channel,
                    exc_info=True,
                )
                stale.append(ws)
        if stale:
            for ws in stale:
                sockets.discard(ws)
            log.info(
                "WS pruned stale sockets: channel=%s dropped=%d total=%d",
                channel,
                len(stale),
                len(sockets),
            )
        # Only remove the channel if it still maps to this (now empty) set; a
        # concurrent disconnect()+connect() may have replaced it meanwhile.
        if not sockets and self._connections.get(channel) is sockets:
            del self._connections[channel]

    async def send_personal(self, websocket: WebSocket, data: Any) -> None:
        """Send *data* (JSON-encoded unless already a str) to a single socket.

        Unlike ``broadcast``, errors from the send propagate to the caller.
        """
        await websocket.send_text(self._encode(data))

    @property
    def active_count(self) -> int:
        """Total number of registered sockets across all channels."""
        return sum(len(s) for s in self._connections.values())


# Singleton shared across the app
manager = ConnectionManager()