""" HTTP metrics tracking middleware. Collects RPS, response time, error rate, active requests. """ import time from collections import deque from datetime import datetime, timezone from threading import Lock from typing import Deque, Tuple from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response class HTTPMetricsCollector: """Thread-safe collector for HTTP metrics.""" def __init__(self, window_seconds: int = 60): self.window_seconds = window_seconds self.requests: Deque[Tuple[float, int, float]] = deque() # (timestamp, status_code, duration_ms) self.active_requests = 0 self.lock = Lock() def record_request(self, status_code: int, duration_ms: float): """Record a completed request.""" now = time.time() with self.lock: self.requests.append((now, status_code, duration_ms)) self._cleanup_old_requests(now) def increment_active(self): """Increment active request counter.""" with self.lock: self.active_requests += 1 def decrement_active(self): """Decrement active request counter.""" with self.lock: self.active_requests = max(0, self.active_requests - 1) def _cleanup_old_requests(self, now: float): """Remove requests older than window.""" cutoff = now - self.window_seconds while self.requests and self.requests[0][0] < cutoff: self.requests.popleft() def get_metrics(self) -> dict: """Get current metrics snapshot.""" now = time.time() with self.lock: self._cleanup_old_requests(now) if not self.requests: return { "requests_per_sec": 0, "avg_response_time_ms": 0, "error_rate": 0, "active_requests": self.active_requests, } total_requests = len(self.requests) error_requests = sum(1 for _, status, _ in self.requests if status >= 400) total_duration = sum(duration for _, _, duration in self.requests) return { "requests_per_sec": int(total_requests / self.window_seconds), "avg_response_time_ms": total_duration / total_requests if total_requests > 0 else 0, "error_rate": (error_requests / total_requests * 100) if total_requests > 0 else 0, "active_requests": self.active_requests, } # Global collector instance http_metrics_collector = HTTPMetricsCollector(window_seconds=60) class HTTPMetricsMiddleware(BaseHTTPMiddleware): """Middleware to track HTTP request metrics.""" async def dispatch(self, request: Request, call_next): # Skip metrics endpoints to avoid recursive counting if request.url.path.startswith("/api/v1/superadmin/monitoring"): return await call_next(request) http_metrics_collector.increment_active() start_time = time.time() try: response: Response = await call_next(request) duration_ms = (time.time() - start_time) * 1000 http_metrics_collector.record_request(response.status_code, duration_ms) return response except Exception as e: duration_ms = (time.time() - start_time) * 1000 http_metrics_collector.record_request(500, duration_ms) raise finally: http_metrics_collector.decrement_active()