"""Cache panel for tracking cache operations during requests."""
from __future__ import annotations

import functools
import threading
import time
from collections import Counter
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from debug_toolbar.core.panel import Panel

if TYPE_CHECKING:
    from collections.abc import Callable, Generator

    from debug_toolbar.core.context import RequestContext
    from debug_toolbar.core.toolbar import DebugToolbar
# Coarse category labels attached to every recorded cache operation.
CacheOperation = Literal[
    "GET",
    "SET",
    "DELETE",
    "INCR",
    "DECR",
    "MGET",
    "MSET",
    "EXISTS",
    "EXPIRE",
    "OTHER",
]
@dataclass
class CacheOperationRecord:
    """One recorded cache call and its measurements."""

    # Category label for the call (one of the ``CacheOperation`` literals).
    operation: CacheOperation
    # Key, or list of keys for multi-key calls, that the call touched.
    key: str | list[str]
    # True/False when the call was classified as hit/miss; None when not applicable.
    hit: bool | None
    # Elapsed time of the call, in seconds.
    duration: float
    # Wall-clock time at which the record was created.
    timestamp: float
    # Name of the cache backend, e.g. "redis" or "memcached".
    backend: str
    # Free-form additional details; each record gets its own dict.
    extra: dict[str, Any] = field(default_factory=dict)
# Guards installing/removing the client-class patches and CacheTracker's
# shared class-level bookkeeping.
_patch_lock = threading.Lock()


class CacheTracker:
    """Tracks cache operations for Redis and memcached.

    Tracking replaces methods on the client *classes* (``redis.Redis`` and
    ``pymemcache.client.base.Client``), so the patches are process-wide. Each
    wrapper records into whichever tracker ``_get_tracker()`` returns for the
    current context, which lets a single set of patches serve every tracker.

    The saved original methods and the number of trackers currently tracking
    are therefore shared class-level state, mutated only while holding
    ``_patch_lock``. The first tracker to start installs the patches and the
    last one to stop removes them. One tracker stopping thus never unpatches
    the clients while another tracker is still tracking.
    """

    # Original methods replaced on the client classes, keyed by method name.
    _original_redis_methods: ClassVar[dict[str, Any]] = {}
    _original_memcache_methods: ClassVar[dict[str, Any]] = {}
    # Number of trackers currently between start_tracking() and stop_tracking().
    _patch_refcount: ClassVar[int] = 0
    # redis-py commands whose positional arguments are all keys, e.g. delete("a", "b").
    _REDIS_MULTI_KEY_METHODS: ClassVar[frozenset[str]] = frozenset({"delete", "exists", "mget"})

    def __init__(self) -> None:
        self.operations: list[CacheOperationRecord] = []
        self._tracking_enabled = False

    def start_tracking(self) -> None:
        """Start tracking cache operations, installing the patches if needed.

        A no-op if this tracker is already tracking.
        """
        with _patch_lock:
            if self._tracking_enabled:
                return
            if CacheTracker._patch_refcount == 0:
                self._patch_redis()
                self._patch_memcache()
            CacheTracker._patch_refcount += 1
            self._tracking_enabled = True

    def stop_tracking(self) -> None:
        """Stop tracking; restore the original methods once no tracker remains.

        A no-op if this tracker is not tracking.
        """
        with _patch_lock:
            if not self._tracking_enabled:
                return
            self._tracking_enabled = False
            CacheTracker._patch_refcount -= 1
            if CacheTracker._patch_refcount == 0:
                self._unpatch_redis()
                self._unpatch_memcache()

    def clear(self) -> None:
        """Clear tracked operations."""
        self.operations = []

    # -- hit/miss classification ---------------------------------------------

    @staticmethod
    def _hit_not_none(result: Any) -> bool:
        """Single-key read: a hit unless the client returned ``None``."""
        return result is not None

    @staticmethod
    def _hit_any_value(result: Any) -> bool:
        """Multi-key read: a hit when at least one key was found.

        Dicts (pymemcache ``get_multi``) hit when non-empty. Lists/tuples
        (redis ``mget``, which returns ``None`` per missing key) hit when any
        element is not ``None``. Any other result type falls back to a
        ``None`` check.
        """
        if isinstance(result, dict):
            return len(result) > 0
        if isinstance(result, (list, tuple)):
            return any(value is not None for value in result)
        return result is not None

    @staticmethod
    def _hit_positive_count(result: Any) -> bool:
        """Counting read (redis ``exists``): a hit when the returned count is > 0."""
        if isinstance(result, int):
            return result > 0
        return result is not None

    # -- key normalization ---------------------------------------------------

    @staticmethod
    def _key_to_str(value: Any) -> str:
        """Render one key as text; bytes are decoded as UTF-8 with replacement."""
        if isinstance(value, str):
            return value
        if isinstance(value, (bytes, bytearray, memoryview)):
            return bytes(value).decode("utf-8", errors="replace")
        return str(value)

    @classmethod
    def _describe_key(cls, value: Any) -> str | list[str]:
        """Convert a key argument into ``str``, or ``list[str]`` for collections.

        Dicts contribute their keys (e.g. redis ``mset`` mappings).
        """
        if isinstance(value, (dict, list, tuple, set, frozenset)):
            return [cls._key_to_str(item) for item in value]
        return cls._key_to_str(value)

    @classmethod
    def _extract_key(
        cls,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        key_kwargs: tuple[str, ...],
        *,
        multi_key: bool = False,
    ) -> str | list[str]:
        """Pick the key(s) a client call operated on.

        For *multi_key* methods with several positional arguments, every
        positional argument is a key and all are returned (flattened).
        Otherwise the first positional argument is used, then the first of
        *key_kwargs* present in *kwargs*, else ``"unknown"``.
        """
        if multi_key and len(args) > 1:
            keys: list[str] = []
            for arg in args:
                described = cls._describe_key(arg)
                if isinstance(described, list):
                    keys.extend(described)
                else:
                    keys.append(described)
            return keys
        if args:
            return cls._describe_key(args[0])
        for name in key_kwargs:
            if name in kwargs:
                return cls._describe_key(kwargs[name])
        return "unknown"

    # -- patching --------------------------------------------------------------

    @classmethod
    def _make_wrapper(
        cls,
        original: Any,
        operation: CacheOperation,
        hit_check: Callable[[Any], bool] | None,
        *,
        backend: str,
        key_kwargs: tuple[str, ...],
        multi_key: bool,
    ) -> Any:
        """Wrap *original* so each successful call is timed and recorded.

        The call is recorded only when a tracker is active in the current
        context; the original result is always returned unchanged. Calls that
        raise propagate the exception and are not recorded.
        """

        @functools.wraps(original)
        def wrapper(client: Any, *args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            result = original(client, *args, **kwargs)
            duration = time.perf_counter() - start
            tracker = _get_tracker()
            if tracker is not None:
                tracker._record_operation(  # noqa: SLF001
                    operation=operation,
                    key=cls._extract_key(args, kwargs, key_kwargs, multi_key=multi_key),
                    hit=None if hit_check is None else hit_check(result),
                    duration=duration,
                    backend=backend,
                )
            return result

        return wrapper

    @classmethod
    def _patch_class(
        cls,
        target: type,
        methods: dict[str, tuple[CacheOperation, Callable[[Any], bool] | None]],
        saved: dict[str, Any],
        *,
        backend: str,
        key_kwargs: tuple[str, ...],
        multi_key_methods: frozenset[str] = frozenset(),
    ) -> None:
        """Replace *methods* on *target* with recording wrappers.

        Originals are stored in *saved* so ``_unpatch_class`` can restore them.
        Methods *target* does not have are skipped. A class already marked
        ``_debug_toolbar_patched`` is left alone, so it is never wrapped twice.
        """
        if hasattr(target, "_debug_toolbar_patched"):
            return
        for method_name, (operation, hit_check) in methods.items():
            original = getattr(target, method_name, None)
            if original is None:
                continue
            saved[method_name] = original
            setattr(
                target,
                method_name,
                cls._make_wrapper(
                    original,
                    operation,
                    hit_check,
                    backend=backend,
                    key_kwargs=key_kwargs,
                    multi_key=method_name in multi_key_methods,
                ),
            )
        target._debug_toolbar_patched = True  # type: ignore[attr-defined]  # noqa: SLF001

    @staticmethod
    def _unpatch_class(target: type, saved: dict[str, Any]) -> None:
        """Restore the originals in *saved* onto *target* and clear *saved*."""
        if not hasattr(target, "_debug_toolbar_patched"):
            return
        for method_name, original in saved.items():
            setattr(target, method_name, original)
        delattr(target, "_debug_toolbar_patched")
        saved.clear()

    def _patch_redis(self) -> None:
        """Patch ``redis.Redis`` methods; a no-op if redis is not installed."""
        try:
            import redis  # type: ignore[import-untyped]
        except ImportError:
            return
        read = self._hit_not_none
        methods: dict[str, tuple[CacheOperation, Callable[[Any], bool] | None]] = {
            "get": ("GET", read),
            "set": ("SET", None),
            "delete": ("DELETE", None),
            "mget": ("MGET", self._hit_any_value),
            "mset": ("MSET", None),
            "incr": ("INCR", None),
            "decr": ("DECR", None),
            "exists": ("EXISTS", self._hit_positive_count),
            "expire": ("EXPIRE", None),
            "setex": ("SET", None),
            "setnx": ("SET", None),
            "getset": ("GET", read),
            "hget": ("GET", read),
            "hset": ("SET", None),
            "hdel": ("DELETE", None),
            "sadd": ("SET", None),
            "srem": ("DELETE", None),
            "lpush": ("SET", None),
            "rpush": ("SET", None),
            "lpop": ("GET", read),
            "rpop": ("GET", read),
        }
        self._patch_class(
            redis.Redis,
            methods,
            self._original_redis_methods,
            backend="redis",
            key_kwargs=("name", "keys", "mapping"),
            multi_key_methods=self._REDIS_MULTI_KEY_METHODS,
        )

    def _unpatch_redis(self) -> None:
        """Restore original Redis methods; a no-op if redis is not installed."""
        try:
            import redis  # type: ignore[import-untyped]
        except ImportError:
            return
        self._unpatch_class(redis.Redis, self._original_redis_methods)

    def _patch_memcache(self) -> None:
        """Patch pymemcache ``Client`` methods; a no-op if it is not installed."""
        try:
            from pymemcache.client.base import Client  # type: ignore[import-untyped]
        except ImportError:
            return
        methods: dict[str, tuple[CacheOperation, Callable[[Any], bool] | None]] = {
            "get": ("GET", self._hit_not_none),
            "set": ("SET", None),
            "delete": ("DELETE", None),
            "get_multi": ("MGET", self._hit_any_value),
            "set_multi": ("MSET", None),
            "delete_multi": ("DELETE", None),
            "incr": ("INCR", None),
            "decr": ("DECR", None),
            "add": ("SET", None),
            "replace": ("SET", None),
            "append": ("SET", None),
            "prepend": ("SET", None),
        }
        self._patch_class(
            Client,
            methods,
            self._original_memcache_methods,
            backend="memcached",
            key_kwargs=("key", "keys", "values"),
        )

    def _unpatch_memcache(self) -> None:
        """Restore original pymemcache methods; a no-op if it is not installed."""
        try:
            from pymemcache.client.base import Client  # type: ignore[import-untyped]
        except ImportError:
            return
        self._unpatch_class(Client, self._original_memcache_methods)

    # -- recording -------------------------------------------------------------

    def _record_operation(
        self,
        operation: CacheOperation,
        key: str | list[str],
        hit: bool | None,  # noqa: FBT001
        duration: float,
        backend: str,
        extra: dict[str, Any] | None = None,
    ) -> None:
        """Append a record for one cache operation, stamped with ``time.time()``."""
        self.operations.append(
            CacheOperationRecord(
                operation=operation,
                key=key,
                hit=hit,
                duration=duration,
                timestamp=time.time(),
                backend=backend,
                extra=extra or {},
            )
        )

    @contextmanager
    def track_operation(
        self,
        operation: CacheOperation,
        key: str | list[str],
        backend: str,
    ) -> Generator[dict[str, Any], None, None]:
        """Context manager for tracking custom cache operations.

        Yields a dict the caller may fill with extra details. Setting ``"hit"``
        to ``True``/``False`` records a hit or miss. The operation is recorded
        when the ``with`` block exits, including when it raises.
        """
        start = time.perf_counter()
        extra: dict[str, Any] = {}
        try:
            yield extra
        finally:
            self._record_operation(
                operation=operation,
                key=key,
                hit=extra.get("hit"),
                duration=time.perf_counter() - start,
                backend=backend,
                extra=extra,
            )
_active_tracker: ContextVar[CacheTracker | None] = ContextVar("_active_tracker", default=None)
def _get_tracker() -> CacheTracker | None:
"""Get the currently active cache tracker."""
return _active_tracker.get()
def _set_tracker(tracker: CacheTracker | None) -> None:
"""Set the active cache tracker."""
_active_tracker.set(tracker)
class CachePanel(Panel):
    """Panel displaying cache operations during the request.

    Tracks:
    - Cache operations (GET, SET, DELETE, etc.)
    - Hit/miss status
    - Operation duration
    - Backend type (Redis, memcached)
    - Aggregate statistics
    """

    panel_id: ClassVar[str] = "CachePanel"
    title: ClassVar[str] = "Cache"
    template: ClassVar[str] = "panels/cache.html"
    has_content: ClassVar[bool] = True
    nav_title: ClassVar[str] = "Cache"

    __slots__ = ("_tracker",)

    def __init__(self, toolbar: DebugToolbar) -> None:
        super().__init__(toolbar)
        # One tracker per panel; it is cleared at the start of each request.
        self._tracker = CacheTracker()

    async def process_request(self, context: RequestContext) -> None:
        """Start tracking cache operations.

        Clears records left from a previous request, makes this panel's
        tracker the active one for the current context, then starts tracking.
        """
        self._tracker.clear()
        _set_tracker(self._tracker)
        self._tracker.start_tracking()

    async def process_response(self, context: RequestContext) -> None:
        """Stop tracking cache operations and unset the active tracker."""
        self._tracker.stop_tracking()
        _set_tracker(None)

    @staticmethod
    def _format_key(key: Any) -> str:
        """Render a recorded key for display.

        Strings pass through unchanged. Bytes are decoded as UTF-8, with
        undecodable bytes replaced. Collections become comma-separated items,
        and anything else goes through ``str()``. Plain ``",".join`` would
        raise on bytes or int keys, which Redis clients accept.
        """
        if isinstance(key, str):
            return key
        if isinstance(key, (bytes, bytearray, memoryview)):
            return bytes(key).decode("utf-8", errors="replace")
        if isinstance(key, (list, tuple, set, frozenset, dict)):
            return ",".join(CachePanel._format_key(part) for part in key)
        return str(key)

    async def generate_stats(self, context: RequestContext) -> dict[str, Any]:
        """Generate cache statistics.

        Returns the per-operation list plus aggregates:
        - hit and miss counts;
        - hit rate, as the percentage of hit-or-miss operations that hit;
        - total and average duration in seconds;
        - counts per operation type and per backend.
        Also records ``cache_time`` on *context* when the total is non-zero.
        """
        operations = self._tracker.operations
        total_operations = len(operations)
        hits = sum(1 for op in operations if op.hit is True)
        misses = sum(1 for op in operations if op.hit is False)
        total_time = sum(op.duration for op in operations)
        lookups = hits + misses
        hit_rate = hits / lookups * 100 if lookups > 0 else 0.0
        avg_time = total_time / total_operations if total_operations > 0 else 0.0

        operation_list = [
            {
                "operation": op.operation,
                "key": self._format_key(op.key),
                "hit": op.hit,
                "duration": op.duration,
                "duration_ms": op.duration * 1000,
                "timestamp": op.timestamp,
                "backend": op.backend,
                "extra": op.extra,
            }
            for op in operations
        ]

        stats = {
            "operations": operation_list,
            "total_operations": total_operations,
            "hits": hits,
            "misses": misses,
            "hit_rate": hit_rate,
            "total_time": total_time,
            "avg_time": avg_time,
            "backends": sorted({op.backend for op in operations}),
            # Plain dicts (first-seen order) for template/serialization friendliness.
            "by_operation": dict(Counter(op.operation for op in operations)),
            "by_backend": dict(Counter(op.backend for op in operations)),
        }

        if total_time > 0:
            context.record_timing("cache_time", total_time)

        return stats

    def generate_server_timing(self, context: RequestContext) -> dict[str, float]:
        """Generate Server-Timing data for cache operations.

        Emits ``cache`` for the total time and ``cache-<backend>`` for each
        backend with non-zero time; returns ``{}`` when there are no stats.
        """
        stats = self.get_stats(context)
        if not stats:
            return {}

        timing: dict[str, float] = {}
        total_time = stats.get("total_time", 0)
        if total_time > 0:
            timing["cache"] = total_time

        # Sum durations per backend in a single pass over the operations.
        backend_times: dict[str, float] = {}
        for op in stats.get("operations", []):
            backend_times[op["backend"]] = backend_times.get(op["backend"], 0.0) + op["duration"]

        for backend in stats.get("by_backend", {}):
            backend_time = backend_times.get(backend, 0.0)
            if backend_time > 0:
                timing[f"cache-{backend}"] = backend_time

        return timing

    def get_nav_subtitle(self) -> str:
        """Return the navigation subtitle; this panel always returns ``""``."""
        return ""