# Source code for debug_toolbar.litestar.middleware

"""Debug toolbar middleware for Litestar."""

from __future__ import annotations

import logging
import re
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4

from debug_toolbar.core import DebugToolbar, RequestContext, set_request_context
from debug_toolbar.core.panels.websocket import WebSocketConnection, WebSocketMessage, WebSocketPanel
from debug_toolbar.litestar.config import LitestarDebugToolbarConfig
from debug_toolbar.litestar.panels.events import collect_events_metadata
from litestar.middleware import AbstractMiddleware

if TYPE_CHECKING:
    from litestar.types import (
        ASGIApp,
        HTTPResponseBodyEvent,
        HTTPResponseStartEvent,
        Message,
        Receive,
        Scope,
        Send,
    )

    from litestar import Request

logger = logging.getLogger(__name__)


@dataclass
class ResponseState:
    """Mutable per-request record of what the middleware has observed and sent.

    A fresh instance is created for each instrumented HTTP request and shared by
    the send-wrapper handlers while the response flows through the middleware.
    """

    # Set once the application's http.response.start message has been seen.
    started: bool = field(default=False)
    # Body chunks buffered (for HTML responses) until the final chunk arrives.
    body_chunks: list[bytes] = field(default_factory=list)
    # Decoded response headers keyed by name; repeated names keep only the last value.
    headers: dict[str, str] = field(default_factory=dict)
    # Status code reported by the application's response start message.
    status_code: int = field(default=200)
    # Whether the response content type marks it as HTML.
    is_html: bool = field(default=False)
    # Set once the middleware has forwarded the response downstream.
    headers_sent: bool = field(default=False)
    # Header pairs exactly as supplied by the application.
    original_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
class DebugToolbarMiddleware(AbstractMiddleware):
    """Litestar middleware for the debug toolbar.

    This middleware:
    - Initializes the request context for each request
    - Collects request/response metadata
    - Injects the toolbar HTML into HTML responses
    - Adds Server-Timing headers
    - Tracks WebSocket connections and messages when enabled in the config
    """

    scopes = {"http", "websocket"}
    exclude = ["_debug_toolbar"]

    def __init__(
        self,
        app: ASGIApp,
        config: LitestarDebugToolbarConfig | None = None,
        toolbar: DebugToolbar | None = None,
    ) -> None:
        """Initialize the middleware.

        Args:
            app: The next ASGI application.
            config: Toolbar configuration. Uses defaults if not provided.
            toolbar: Optional shared toolbar instance. Creates new if not provided.
        """
        super().__init__(app)
        self.config = config or LitestarDebugToolbarConfig()
        self.toolbar = toolbar or DebugToolbar(self.config)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process an ASGI request.

        WebSocket scopes go to connection tracking unless their path is excluded,
        other non-HTTP scopes pass straight through, and HTTP requests for which
        the toolbar should be shown are instrumented via a wrapped ``send``.
        """
        path = scope.get("path", "/")
        if scope["type"] == "websocket":
            if any(path.startswith(excluded) for excluded in self.config.exclude_paths):
                await self.app(scope, receive, send)
                return
            await self._handle_websocket(scope, receive, send)
            return
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        from litestar import Request

        request = Request(scope)
        if not self.config.should_show_toolbar(request):
            await self.app(scope, receive, send)
            return

        context = await self.toolbar.process_request()
        scope["_debug_toolbar_context"] = context  # type: ignore[typeddict-unknown-key]
        self._populate_request_metadata(request, context)
        self._populate_events_metadata(request, context)

        state = ResponseState()
        send_wrapper = self._create_send_wrapper(send, context, state)
        try:
            await self.app(scope, receive, send_wrapper)
        except Exception:
            await self._handle_exception(send, state)
            raise
        finally:
            # Reset the current request context once this request is done
            # (presumably populated by process_request — confirm in core).
            set_request_context(None)

    def _create_send_wrapper(self, send: Send, context: RequestContext, state: ResponseState) -> Send:
        """Create a send wrapper that intercepts and modifies responses."""

        async def send_wrapper(message: Message) -> None:
            if message["type"] == "http.response.start":
                await self._handle_response_start(send, message, context, state)
            elif message["type"] == "http.response.body":
                await self._handle_response_body(send, message, context, state)
            else:
                await send(message)

        return send_wrapper

    @staticmethod
    def _decode_header_part(part: bytes | str) -> str:
        """Decode a raw ASGI header name/value for display without ever raising."""
        if isinstance(part, (bytes, bytearray)):
            return bytes(part).decode("utf-8", errors="replace")
        return part

    @staticmethod
    def _encode_header_part(part: bytes | str) -> bytes:
        """Return a header name/value as bytes, encoding ``str`` input as UTF-8."""
        if isinstance(part, (bytes, bytearray)):
            return bytes(part)
        return part.encode("utf-8")

    @classmethod
    def _headers_with_length(
        cls,
        original_headers: list[tuple[bytes, bytes]],
        body_length: int,
        server_timing: str | None = None,
    ) -> list[tuple[bytes, bytes]]:
        """Copy the raw headers, replacing content-length and optionally adding server-timing.

        Works on the raw header list (not a dict) so repeated headers such as
        ``Set-Cookie`` are preserved.
        """
        headers: list[tuple[bytes, bytes]] = []
        for name, value in original_headers:
            raw_name = cls._encode_header_part(name)
            if raw_name.lower() == b"content-length":
                continue
            headers.append((raw_name, cls._encode_header_part(value)))
        headers.append((b"content-length", str(body_length).encode()))
        if server_timing:
            headers.append((b"server-timing", server_timing.encode()))
        return headers

    async def _handle_response_start(
        self,
        send: Send,
        message: Message,
        context: RequestContext,
        state: ResponseState,
    ) -> None:
        """Handle an ``http.response.start`` message.

        Records status and headers in ``state`` and ``context.metadata``. Non-HTML
        responses are forwarded immediately; HTML responses are held back until
        the full body has been buffered so the toolbar can be injected.
        """
        state.started = True
        start_msg = cast("HTTPResponseStartEvent", message)
        state.status_code = start_msg["status"]
        state.original_headers = list(start_msg.get("headers", []))
        # Decoded copy for metadata/display; responses are rebuilt from the raw list.
        state.headers = {
            self._decode_header_part(name): self._decode_header_part(value)
            for name, value in state.original_headers
        }
        content_type = next(
            (value for name, value in state.headers.items() if name.lower() == "content-type"),
            "",
        )
        context.metadata["status_code"] = state.status_code
        context.metadata["response_headers"] = state.headers
        context.metadata["response_content_type"] = content_type
        state.is_html = "text/html" in content_type.lower()
        if not state.is_html:
            await self._send_non_html_start(send, context, state)

    async def _send_non_html_start(self, send: Send, context: RequestContext, state: ResponseState) -> None:
        """Send the response start for a non-HTML response, adding Server-Timing when available.

        Toolbar processing is best-effort: a failure is logged and the original
        headers are still sent, mirroring the HTML path.
        """
        server_timing: str | None = None
        try:
            await self.toolbar.process_response(context)
            server_timing = self.toolbar.get_server_timing_header(context)
        except Exception:
            logger.debug("Toolbar processing failed, sending original response headers", exc_info=True)

        new_headers = list(state.original_headers)
        if server_timing:
            new_headers.append((b"server-timing", server_timing.encode()))
        modified_msg: HTTPResponseStartEvent = {
            "type": "http.response.start",
            "status": state.status_code,
            "headers": new_headers,
        }
        await send(modified_msg)
        state.headers_sent = True

    async def _handle_response_body(
        self,
        send: Send,
        message: Message,
        context: RequestContext,
        state: ResponseState,
    ) -> None:
        """Handle ``http.response.body``: pass through non-HTML, buffer HTML until the last chunk."""
        body_msg = cast("HTTPResponseBodyEvent", message)
        body = body_msg.get("body", b"")
        if not state.is_html:
            await send(message)
            return
        state.body_chunks.append(body)
        if not body_msg.get("more_body", False):
            await self._send_html_response(send, context, state)

    async def _send_html_response(self, send: Send, context: RequestContext, state: ResponseState) -> None:
        """Process and send the buffered HTML response with the toolbar injected.

        Falls back to the unmodified body if toolbar processing fails.
        """
        full_body = b"".join(state.body_chunks)
        try:
            await self.toolbar.process_response(context)
            modified_body = self._inject_toolbar(full_body, context)
            server_timing = self.toolbar.get_server_timing_header(context)
        except Exception:
            logger.debug("Toolbar processing failed, sending original response", exc_info=True)
            modified_body = full_body
            server_timing = None

        start_event: HTTPResponseStartEvent = {
            "type": "http.response.start",
            "status": state.status_code,
            "headers": self._headers_with_length(state.original_headers, len(modified_body), server_timing),
        }
        await send(start_event)
        # Mark as sent before the body goes out so exception handling never
        # attempts a second http.response.start.
        state.headers_sent = True
        body_event: HTTPResponseBodyEvent = {
            "type": "http.response.body",
            "body": modified_body,
            "more_body": False,
        }
        await send(body_event)

    async def _handle_exception(self, send: Send, state: ResponseState) -> None:
        """Flush a buffered HTML response if the app raised before it was sent.

        Content-length is recomputed because the buffered body may be partial.
        Send failures are logged and suppressed; the caller re-raises the
        original exception.
        """
        if not (state.started and state.is_html and not state.headers_sent):
            return
        body = b"".join(state.body_chunks)
        try:
            start_event: HTTPResponseStartEvent = {
                "type": "http.response.start",
                "status": state.status_code,
                "headers": self._headers_with_length(state.original_headers, len(body)),
            }
            await send(start_event)
            state.headers_sent = True
            body_event: HTTPResponseBodyEvent = {
                "type": "http.response.body",
                "body": body,
                "more_body": False,
            }
            await send(body_event)
        except Exception:
            logger.debug("Failed to send buffered response during exception handling", exc_info=True)

    async def _handle_websocket(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle WebSocket connection tracking.

        When tracking is disabled the connection is passed straight through.
        Otherwise a WebSocketConnection is registered with WebSocketPanel and the
        app runs with wrapped receive/send callables that record traffic.

        Args:
            scope: ASGI scope for the WebSocket connection.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        if not self.config.websocket_tracking_enabled:
            await self.app(scope, receive, send)
            return

        connection = WebSocketConnection(
            connection_id=str(uuid4()),
            path=scope.get("path", "/"),
            query_string=scope.get("query_string", b"").decode("utf-8", errors="replace"),
            headers={
                k.decode("utf-8", errors="replace"): v.decode("utf-8", errors="replace")
                for k, v in scope.get("headers", [])
            },
            connected_at=time.time(),
            state="connecting",
        )
        WebSocketPanel.track_connection(
            connection,
            ttl=self.config.websocket_connection_ttl,
            max_connections=self.config.websocket_max_connections,
        )
        logger.debug("WebSocket tracked: %s at %s", connection.connection_id[:8], connection.path)

        send_wrapper = self._create_websocket_send_wrapper(send, connection)
        receive_wrapper = self._create_websocket_receive_wrapper(receive, connection)
        try:
            await self.app(scope, receive_wrapper, send_wrapper)
        finally:
            self._finalize_websocket_connection(connection)

    def _finalize_websocket_connection(self, connection: WebSocketConnection) -> None:
        """Mark a tracked connection closed once the application has returned.

        A server-initiated close leaves the connection in ``"closing"``; if the
        app returns without reading the client's disconnect, nothing else would
        move it to ``"closed"``, so it is finalized here too.
        """
        if connection.state == "closed":
            return
        connection.state = "closed"
        connection.disconnected_at = time.time()
        try:
            WebSocketPanel.broadcast_state_change(connection.connection_id, "closed", connection.close_code)
        except Exception:
            logger.debug("Error broadcasting WebSocket close", exc_info=True)

    @staticmethod
    def _extract_websocket_payload(message: Message) -> str | bytes | None:
        """Return a websocket message's ``text`` payload, falling back to ``bytes``.

        Checks for ``None`` explicitly so empty text frames are still recorded.
        """
        text = message.get("text")
        if text is not None:
            return text
        return message.get("bytes")

    def _create_websocket_message(
        self,
        direction: Literal["sent", "received"],
        message_data: str | bytes,
    ) -> WebSocketMessage:
        """Create a WebSocketMessage from message data.

        Args:
            direction: Whether the message was sent or received.
            message_data: The message content (text or bytes).

        Returns:
            A WebSocketMessage whose content is truncated to at most
            ``websocket_max_message_size`` bytes (UTF-8 for text).
        """
        limit = self.config.websocket_max_message_size
        if isinstance(message_data, str):
            message_type: Literal["text", "binary"] = "text"
            encoded = message_data.encode("utf-8")
            size_bytes = len(encoded)
            truncated = size_bytes > limit
            # Truncate in the same unit size_bytes is measured in (UTF-8 bytes),
            # dropping any multi-byte character split at the boundary.
            content: str | bytes = encoded[:limit].decode("utf-8", errors="ignore") if truncated else message_data
        else:
            message_type = "binary"
            size_bytes = len(message_data)
            truncated = size_bytes > limit
            content = message_data[:limit] if truncated else message_data

        return WebSocketMessage(
            direction=direction,
            message_type=message_type,
            content=content,
            timestamp=time.time(),
            size_bytes=size_bytes,
            truncated=truncated,
        )

    def _create_websocket_send_wrapper(self, send: Send, connection: WebSocketConnection) -> Send:
        """Create a send wrapper for WebSocket messages.

        Args:
            send: The original ASGI send callable.
            connection: The WebSocket connection being tracked.

        Returns:
            A wrapped send callable that tracks WebSocket messages. Tracking
            errors are logged and never prevent the message from being sent.
        """

        async def send_wrapper(message: Message) -> None:
            try:
                msg_type = message.get("type", "")
                if msg_type == "websocket.accept":
                    connection.state = "connected"
                    WebSocketPanel.broadcast_state_change(connection.connection_id, "connected")
                elif msg_type == "websocket.send":
                    message_data = self._extract_websocket_payload(message)
                    if message_data is not None:
                        ws_message = self._create_websocket_message("sent", message_data)
                        connection.add_message(
                            ws_message, max_messages=self.config.websocket_max_messages_per_connection
                        )
                        WebSocketPanel.broadcast_message(connection.connection_id, ws_message)
                elif msg_type == "websocket.close":
                    connection.state = "closing"
                    connection.close_code = message.get("code")
                    connection.close_reason = message.get("reason") or ""
                    WebSocketPanel.broadcast_state_change(connection.connection_id, "closing", connection.close_code)
            except Exception:
                logger.debug("Error tracking WebSocket send message", exc_info=True)
            await send(message)

        return send_wrapper

    def _create_websocket_receive_wrapper(self, receive: Receive, connection: WebSocketConnection) -> Receive:
        """Create a receive wrapper for WebSocket messages.

        Args:
            receive: The original ASGI receive callable.
            connection: The WebSocket connection being tracked.

        Returns:
            A wrapped receive callable that tracks WebSocket messages. Tracking
            errors are logged; the received message is always returned.
        """

        async def receive_wrapper() -> Any:
            message = await receive()
            try:
                msg_type = message.get("type", "")
                if msg_type == "websocket.receive":
                    message_data = self._extract_websocket_payload(message)
                    if message_data is not None:
                        ws_message = self._create_websocket_message("received", message_data)
                        connection.add_message(
                            ws_message, max_messages=self.config.websocket_max_messages_per_connection
                        )
                        WebSocketPanel.broadcast_message(connection.connection_id, ws_message)
                elif msg_type == "websocket.disconnect":
                    connection.state = "closed"
                    connection.close_code = message.get("code")
                    connection.disconnected_at = time.time()
                    WebSocketPanel.broadcast_state_change(connection.connection_id, "closed", connection.close_code)
            except Exception:
                logger.debug("Error tracking WebSocket receive message", exc_info=True)
            return message

        return receive_wrapper

    def _populate_request_metadata(self, request: Request, context: RequestContext) -> None:
        """Populate request metadata in the context.

        Args:
            request: The Litestar request.
            context: The request context to populate.
        """
        context.metadata["method"] = request.method
        context.metadata["path"] = request.url.path
        context.metadata["query_string"] = request.url.query
        context.metadata["query_params"] = dict(request.query_params)
        context.metadata["headers"] = dict(request.headers)
        context.metadata["cookies"] = dict(request.cookies)
        context.metadata["content_type"] = request.content_type[0] if request.content_type else ""
        context.metadata["scheme"] = request.url.scheme
        if request.client:
            context.metadata["client_host"] = request.client.host
            context.metadata["client_port"] = request.client.port
        self._populate_routes_metadata(request, context)

    def _populate_events_metadata(self, request: Request, context: RequestContext) -> None:
        """Populate events/lifecycle metadata from the Litestar app.

        On failure an empty events structure is stored so the panel still renders.

        Args:
            request: The Litestar request.
            context: The request context to populate.
        """
        try:
            collect_events_metadata(request.app, context)
        except Exception:
            logger.debug("Failed to collect events metadata", exc_info=True)
            context.metadata["events"] = {
                "lifecycle_hooks": {},
                "request_hooks": {},
                "exception_handlers": [],
                "executed_hooks": [],
            }

    def _populate_routes_metadata(self, request: Request, context: RequestContext) -> None:
        """Populate route information from the Litestar app.

        On failure ``routes`` is set to ``[]`` and ``matched_route`` to ``""``.

        Args:
            request: The Litestar request.
            context: The request context to populate.
        """
        try:
            routes_info: list[dict[str, Any]] = []
            for route in request.app.routes:
                route_data: dict[str, Any] = {
                    "path": route.path,
                    "methods": sorted(getattr(route, "methods", None) or []),
                    "name": getattr(route, "name", None),
                }
                handler = getattr(route, "route_handler", None)
                if handler:
                    fn = getattr(handler, "fn", handler)
                    # Callables such as functools.partial have no __name__.
                    route_data["handler"] = getattr(fn, "__name__", repr(fn))
                    # tags may be None on handlers that declare none.
                    route_data["tags"] = list(getattr(handler, "tags", None) or [])
                routes_info.append(route_data)
            context.metadata["routes"] = routes_info

            route_handler = request.scope.get("route_handler")
            if route_handler:
                context.metadata["matched_route"] = getattr(route_handler, "path", request.url.path)
        except Exception:
            logger.debug("Failed to collect routes metadata", exc_info=True)
            context.metadata["routes"] = []
            context.metadata["matched_route"] = ""

    def _inject_toolbar(self, body: bytes, context: RequestContext) -> bytes:
        """Inject the toolbar HTML into the response body.

        The toolbar is inserted once, before the last case-insensitive occurrence
        of ``config.insert_before``. Bodies that are not valid UTF-8, or that lack
        the marker, are returned unchanged.

        Args:
            body: The original response body.
            context: The request context with collected data.

        Returns:
            The modified response body with toolbar injected.
        """
        try:
            text = body.decode("utf-8")
        except UnicodeDecodeError:
            return body

        toolbar_data = self.toolbar.get_toolbar_data(context)
        toolbar_html = self._render_toolbar(toolbar_data)

        insert_before = self.config.insert_before
        if not insert_before:
            # An empty marker would otherwise "match" between every character.
            return body

        # Use the last match: earlier occurrences may sit inside scripts or
        # comments, while the final one normally closes the document.
        matches = list(re.finditer(re.escape(insert_before), text, re.IGNORECASE))
        if not matches:
            return body
        position = matches[-1].start()
        # Slice instead of re.sub so backslashes in the toolbar HTML are not
        # treated as replacement-template escapes, and the marker keeps its casing.
        return (text[:position] + toolbar_html + text[position:]).encode("utf-8")

    def _render_toolbar(self, data: dict[str, Any]) -> str:
        """Render the toolbar HTML.

        Panel titles/subtitles and the request id are HTML-escaped because they
        may contain request-derived text.

        Args:
            data: Toolbar data from get_toolbar_data().

        Returns:
            HTML string for the toolbar.
        """
        from html import escape

        panels_html = []
        for panel in data.get("panels", []):
            panel_id = escape(str(panel["panel_id"]))
            nav_title = escape(str(panel["nav_title"]))
            subtitle = panel.get("nav_subtitle", "")
            subtitle_html = f'<span class="panel-subtitle">{escape(str(subtitle))}</span>' if subtitle else ""
            panels_html.append(f"""
            <button class="toolbar-panel-btn" data-panel-id="{panel_id}">
                <span class="panel-title">{nav_title}</span>
                {subtitle_html}
            </button>
            """)

        timing = data.get("timing") or {}
        total_time = (timing.get("total_time") or 0) * 1000
        request_id = str(data.get("request_id", "N/A"))
        safe_request_id = escape(request_id)
        short_request_id = escape(request_id[:8])
        panels_markup = "".join(panels_html)

        return f"""
        <link rel="stylesheet" href="/_debug_toolbar/static/toolbar.css">
        <div id="debug-toolbar" data-request-id="{safe_request_id}">
            <div class="toolbar-bar">
                <span class="toolbar-brand" title="Click to toggle">Debug Toolbar</span>
                <span class="toolbar-time">{total_time:.2f}ms</span>
                <div class="toolbar-panels">
                    {panels_markup}
                </div>
                <span class="toolbar-request-id">
                    <a href="/_debug_toolbar/{safe_request_id}" class="toolbar-history-link" title="View request details">{short_request_id}</a>
                </span>
                <a href="/_debug_toolbar/" class="toolbar-history-link" title="View request history">History</a>
            </div>
            <div class="toolbar-details"></div>
        </div>
        <script src="/_debug_toolbar/static/toolbar.js"></script>
        """