Source code for debug_toolbar.core.panels.websocket

"""WebSocket debugging panel for tracking WebSocket connections and messages."""

from __future__ import annotations

import asyncio
import json
import threading
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from uuid import uuid4

from debug_toolbar.core.panel import Panel

if TYPE_CHECKING:
    from debug_toolbar.core.context import RequestContext

# Public API of this module: the names exported by a star-import.
__all__ = ["WebSocketMessage", "WebSocketConnection", "WebSocketPanel"]

# Type alias for a callable that accepts one event payload dict and returns nothing.
# NOTE(review): not referenced in the visible part of this module - confirm it is still used.
EventCallback = Callable[[dict[str, Any]], None]


@dataclass(slots=True)
class WebSocketMessage:
    """Represents a single WebSocket message (sent or received).

    Attributes:
        direction: Whether the message was sent or received.
        message_type: The type of WebSocket frame.
        content: The message payload (text, bytes, or None for control frames).
        timestamp: Unix timestamp in seconds with millisecond precision.
        size_bytes: Size of the message payload in bytes.
        truncated: Whether the content was truncated to meet size limits.
    """

    direction: Literal["sent", "received"]
    message_type: Literal["text", "binary", "ping", "pong", "close"]
    content: str | bytes | None
    timestamp: float
    size_bytes: int
    truncated: bool = False

    def get_content_preview(self, max_length: int = 100) -> str:
        """Get a preview of the message content.

        Args:
            max_length: Maximum length of the preview string.

        Returns:
            A preview string, truncated if necessary. Binary data is shown as hex.
        """
        if self.content is None:
            return "<no content>"

        if isinstance(self.content, bytes):
            return self.get_binary_preview_hex(max_bytes=max_length // 2)

        preview = str(self.content)
        if len(preview) > max_length:
            return preview[:max_length] + "..."
        return preview

    def get_binary_preview_hex(self, max_bytes: int = 16) -> str:
        """Get a hexadecimal preview of binary message content.

        Args:
            max_bytes: Maximum number of bytes to display.

        Returns:
            Hexadecimal representation of the binary data.
        """
        if not isinstance(self.content, bytes):
            return "<not binary>"

        preview_bytes = self.content[:max_bytes]
        hex_str = preview_bytes.hex(" ", 1)

        if len(self.content) > max_bytes:
            return f"{hex_str}... ({len(self.content)} bytes total)"
        return hex_str


@dataclass(slots=True)
class WebSocketConnection:
    """Represents a WebSocket connection with its lifecycle and messages.

    Attributes:
        connection_id: Unique identifier for the connection (UUID).
        path: WebSocket endpoint path.
        query_string: Query string from the connection request.
        headers: HTTP headers from the WebSocket handshake.
        connected_at: Unix timestamp when connection was established.
        disconnected_at: Unix timestamp when connection was closed (None if still open).
        close_code: WebSocket close code (RFC 6455).
        close_reason: Human-readable close reason.
        messages: List of messages exchanged on this connection.
        state: Current connection state.
        total_sent: Count of messages sent.
        total_received: Count of messages received.
        bytes_sent: Total bytes sent.
        bytes_received: Total bytes received.
        messages_dropped: Count of messages dropped due to buffer limits.
    """

    connection_id: str
    path: str
    query_string: str
    headers: dict[str, str]
    connected_at: float
    disconnected_at: float | None = None
    close_code: int | None = None
    close_reason: str | None = None
    messages: list[WebSocketMessage] = field(default_factory=list)
    state: Literal["connecting", "connected", "closing", "closed"] = "connecting"
    total_sent: int = 0
    total_received: int = 0
    bytes_sent: int = 0
    bytes_received: int = 0
    messages_dropped: int = 0

    @staticmethod
    def generate_id() -> str:
        """Generate a unique connection ID.

        Returns:
            A UUID string for the connection.
        """
        return str(uuid4())

    def get_duration(self) -> float:
        """Calculate the connection duration in seconds.

        Returns:
            Duration in seconds. Returns time since connection if still open.
        """
        import time

        end_time = self.disconnected_at if self.disconnected_at is not None else time.time()
        return end_time - self.connected_at

    def add_message(self, message: WebSocketMessage, max_messages: int = 100) -> None:
        """Add a message to the connection, enforcing message buffer limits.

        Args:
            message: The WebSocket message to add.
            max_messages: Maximum number of messages to retain per connection.
        """
        self.messages.append(message)

        if message.direction == "sent":
            self.total_sent += 1
            self.bytes_sent += message.size_bytes
        else:
            self.total_received += 1
            self.bytes_received += message.size_bytes

        if len(self.messages) > max_messages:
            self.messages.pop(0)
            self.messages_dropped += 1

    def get_short_id(self) -> str:
        """Get a shortened version of the connection ID for display.

        Returns:
            First 8 characters of the connection ID.
        """
        return self.connection_id[:8]


class WebSocketPanel(Panel):
    """Panel for displaying WebSocket connection and message information.

    All tracking state lives in class-level containers, so every instance of
    the panel observes the same connections and live subscribers.
    """

    panel_id: ClassVar[str] = "WebSocketPanel"
    title: ClassVar[str] = "WebSocket"
    template: ClassVar[str] = "panels/websocket.html"
    has_content: ClassVar[bool] = True
    nav_title: ClassVar[str] = "WebSocket"

    # Tracked connections keyed by connection ID. Despite the name, this also
    # holds disconnected connections until TTL cleanup or the size limit
    # removes them.
    _active_connections: ClassVar[dict[str, WebSocketConnection]] = {}
    _connections_lock: ClassVar[threading.Lock] = threading.Lock()
    # Queues handed out by ``subscribe``; each receives JSON-encoded events.
    _live_subscribers: ClassVar[set[asyncio.Queue[str]]] = set()
    _subscribers_lock: ClassVar[threading.Lock] = threading.Lock()

    __slots__ = ()

    @classmethod
    def subscribe(cls) -> asyncio.Queue[str]:
        """Register a new listener for live panel events.

        Returns:
            A fresh ``asyncio.Queue`` onto which JSON-encoded events are pushed.
        """
        channel: asyncio.Queue[str] = asyncio.Queue()
        with cls._subscribers_lock:
            cls._live_subscribers.add(channel)
        return channel

    @classmethod
    def unsubscribe(cls, queue: asyncio.Queue[str]) -> None:
        """Stop delivering live events to ``queue``.

        Args:
            queue: A queue previously returned by ``subscribe``. Unknown queues
                are ignored.
        """
        with cls._subscribers_lock:
            cls._live_subscribers.discard(queue)

    @classmethod
    def _broadcast_event(cls, event_type: str, data: dict[str, Any]) -> None:
        """Push one JSON-encoded event onto every subscriber queue.

        Queues that report ``asyncio.QueueFull`` are dropped from the
        subscriber set.

        Args:
            event_type: Value stored under the ``"type"`` key of the event.
            data: Value stored under the ``"data"`` key of the event.
        """
        # Snapshot under the lock so delivery itself happens without holding it.
        with cls._subscribers_lock:
            recipients = set(cls._live_subscribers)
        if not recipients:
            return

        payload = json.dumps({"type": event_type, "data": data})

        overflowing = []
        for recipient in recipients:
            try:
                recipient.put_nowait(payload)
            except asyncio.QueueFull:  # noqa: PERF203 - queue cleanup in loop is fine
                overflowing.append(recipient)

        if not overflowing:
            return
        with cls._subscribers_lock:
            cls._live_subscribers.difference_update(overflowing)

    @classmethod
    def track_connection(cls, connection: WebSocketConnection, ttl: int = 3600, max_connections: int = 50) -> None:
        """Start tracking a connection and announce it to live subscribers.

        Registering a connection also prunes expired disconnected entries and
        enforces the overall size limit.

        Args:
            connection: The connection to track.
            ttl: Seconds a disconnected connection is retained. Defaults to 3600.
            max_connections: Maximum tracked connections; the oldest are dropped
                when exceeded.
        """
        with cls._connections_lock:
            cls._active_connections[connection.connection_id] = connection
            cls._cleanup_old_connections(ttl)
            cls._enforce_connection_limit(max_connections)

        event = {
            "connection_id": connection.connection_id,
            "short_id": connection.get_short_id(),
            "path": connection.path,
            "state": connection.state,
            "connected_at": connection.connected_at,
        }
        cls._broadcast_event("connection", event)

    @classmethod
    def broadcast_message(cls, connection_id: str, message: WebSocketMessage) -> None:
        """Announce a sent or received message to live subscribers.

        Args:
            connection_id: ID of the connection the message belongs to.
            message: The message to describe in the event.
        """
        event = {
            "connection_id": connection_id,
            "direction": message.direction,
            "type": message.message_type,
            "size": message.size_bytes,
            "preview": message.get_content_preview(),
            "timestamp": message.timestamp,
        }
        cls._broadcast_event("message", event)

    @classmethod
    def broadcast_state_change(cls, connection_id: str, state: str, close_code: int | None = None) -> None:
        """Announce a connection state change to live subscribers.

        Args:
            connection_id: ID of the connection whose state changed.
            state: The new state.
            close_code: WebSocket close code, when applicable.
        """
        event = {
            "connection_id": connection_id,
            "state": state,
            "close_code": close_code,
        }
        cls._broadcast_event("state_change", event)

    @classmethod
    def untrack_connection(cls, connection_id: str) -> None:
        """Mark a tracked connection as disconnected.

        The entry is not deleted here; only its ``disconnected_at`` is stamped
        with the current time. Entries whose ``disconnected_at`` is older than
        the TTL are deleted by ``_cleanup_old_connections``.

        Args:
            connection_id: ID of the connection to mark. Unknown IDs are ignored.
        """
        import time

        with cls._connections_lock:
            tracked = cls._active_connections.get(connection_id)
            if tracked is not None:
                tracked.disconnected_at = time.time()

    @classmethod
    def get_connection(cls, connection_id: str) -> WebSocketConnection | None:
        """Look up a tracked connection.

        Args:
            connection_id: The connection ID to look up.

        Returns:
            The tracked connection, or ``None`` when the ID is unknown.
        """
        with cls._connections_lock:
            return cls._active_connections.get(connection_id)

    @classmethod
    def _cleanup_old_connections(cls, ttl: int = 3600) -> None:
        """Delete disconnected connections whose age exceeds ``ttl``.

        Does not take ``_connections_lock`` itself.

        Args:
            ttl: Seconds a disconnected connection is retained. Defaults to 3600.
        """
        import time

        now = time.time()
        expired = [
            conn_id
            for conn_id, conn in cls._active_connections.items()
            if conn.disconnected_at is not None and now - conn.disconnected_at > ttl
        ]
        for conn_id in expired:
            del cls._active_connections[conn_id]

    @classmethod
    def _enforce_connection_limit(cls, max_connections: int) -> None:
        """Trim tracked connections down to ``max_connections``.

        Disconnected connections go first (earliest disconnect first); open
        connections follow (earliest connect first) only if still over the limit.
        Does not take ``_connections_lock`` itself.

        Args:
            max_connections: Maximum number of connections to retain.
        """
        excess = len(cls._active_connections) - max_connections
        if excess <= 0:
            return

        tracked = list(cls._active_connections.values())
        closed_first = sorted(
            (c for c in tracked if c.disconnected_at is not None),
            key=lambda c: c.disconnected_at or 0,
        )
        still_open = sorted(
            (c for c in tracked if c.disconnected_at is None),
            key=lambda c: c.connected_at,
        )

        eviction_order = closed_first + still_open
        for conn in eviction_order[:excess]:
            del cls._active_connections[conn.connection_id]

    async def generate_stats(self, context: RequestContext) -> dict[str, Any]:
        """Generate WebSocket statistics for the panel.

        Args:
            context: The current request context (not consulted).

        Returns:
            The same dictionary that ``get_current_stats`` produces.
        """
        return self.get_current_stats()

    @staticmethod
    def _connection_to_dict(conn: WebSocketConnection) -> dict[str, Any]:
        """Flatten a connection and its retained messages into plain dicts.

        Args:
            conn: The connection to convert.

        Returns:
            Dictionary representation of the connection, with the message list
            under ``"messages"``.
        """

        def render_content(msg: WebSocketMessage) -> str | None:
            # Text passes through; bytes become a hex dump capped at 256 bytes.
            if isinstance(msg.content, str):
                return msg.content
            if isinstance(msg.content, bytes):
                return msg.get_binary_preview_hex(max_bytes=256)
            return None

        summary: dict[str, Any] = {
            "connection_id": conn.connection_id,
            "short_id": conn.get_short_id(),
            "path": conn.path,
            "query_string": conn.query_string,
            "headers": conn.headers,
            "connected_at": conn.connected_at,
            "disconnected_at": conn.disconnected_at,
            "close_code": conn.close_code,
            "close_reason": conn.close_reason,
            "state": conn.state,
            "duration": conn.get_duration(),
            "total_sent": conn.total_sent,
            "total_received": conn.total_received,
            "bytes_sent": conn.bytes_sent,
            "bytes_received": conn.bytes_received,
            "message_count": len(conn.messages),
            "messages_dropped": conn.messages_dropped,
        }
        summary["messages"] = [
            {
                "type": msg.message_type,
                "direction": msg.direction,
                "size": msg.size_bytes,
                "timestamp": msg.timestamp,
                "preview": msg.get_content_preview(),
                "content": render_content(msg),
                "truncated": msg.truncated,
            }
            for msg in conn.messages
        ]
        return summary

    def get_nav_subtitle(self) -> str:
        """Get the navigation subtitle showing the active connection count.

        Returns:
            ``"<n> active"`` when at least one connection is connecting or
            connected, otherwise an empty string.
        """
        live_states = ("connecting", "connected")
        with self._connections_lock:
            active_count = sum(conn.state in live_states for conn in self._active_connections.values())
        return f"{active_count} active" if active_count > 0 else ""

    @classmethod
    def get_current_stats(cls) -> dict[str, Any]:
        """Get current WebSocket statistics without a request context.

        Connections in the ``"closing"`` state appear in neither list but
        still contribute to the message and byte totals.

        Returns:
            Dictionary with the active and closed connection lists, their
            counts, and message/byte totals across all tracked connections.
        """
        with cls._connections_lock:
            snapshot = list(cls._active_connections.values())

        live = [conn for conn in snapshot if conn.state in ("connecting", "connected")]
        finished = [conn for conn in snapshot if conn.state == "closed"]

        sent_count = received_count = sent_bytes = received_bytes = 0
        for conn in snapshot:
            sent_count += conn.total_sent
            received_count += conn.total_received
            sent_bytes += conn.bytes_sent
            received_bytes += conn.bytes_received

        return {
            "active_connections": [cls._connection_to_dict(conn) for conn in live],
            "closed_connections": [cls._connection_to_dict(conn) for conn in finished],
            "total_active": len(live),
            "total_closed": len(finished),
            "total_messages_sent": sent_count,
            "total_messages_received": received_count,
            "total_bytes_sent": sent_bytes,
            "total_bytes_received": received_bytes,
        }