| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- """
- HTTP metrics tracking middleware.
- Collects RPS, response time, error rate, active requests.
- """
- import time
- from collections import deque
- from datetime import datetime, timezone
- from threading import Lock
- from typing import Deque, Tuple
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.requests import Request
- from starlette.responses import Response
class HTTPMetricsCollector:
    """Sliding-window collector for HTTP request metrics.

    Completed requests are stored as ``(timestamp, status_code, duration_ms)``
    tuples in a deque ordered oldest-first. Entries older than
    ``window_seconds`` are dropped whenever a request is recorded or a
    snapshot is taken. Every read and write of the shared state happens while
    holding ``self.lock``.
    """

    def __init__(self, window_seconds: int = 60):
        """Create an empty collector.

        Args:
            window_seconds: Length of the sliding window in seconds.

        Raises:
            ValueError: If ``window_seconds`` is not positive. Zero would
                divide by zero in ``get_metrics`` and a negative window
                would discard every request immediately.
        """
        if window_seconds <= 0:
            raise ValueError(
                f"window_seconds must be positive, got {window_seconds!r}"
            )
        self.window_seconds = window_seconds
        # (timestamp, status_code, duration_ms); the timestamp comes from
        # time.monotonic(), so it is only meaningful relative to other
        # readings of that clock, not as an epoch time.
        self.requests: Deque[Tuple[float, int, float]] = deque()
        self.active_requests = 0
        self.lock = Lock()

    def record_request(self, status_code: int, duration_ms: float):
        """Record a completed request and evict entries outside the window.

        Args:
            status_code: HTTP status code of the response.
            duration_ms: Time spent handling the request, in milliseconds.
        """
        with self.lock:
            # Read the clock while holding the lock so entries are appended
            # in non-decreasing timestamp order; the eviction loop relies on
            # the oldest entry always being at the left end.
            now = time.monotonic()
            self.requests.append((now, status_code, duration_ms))
            self._cleanup_old_requests(now)

    def increment_active(self):
        """Increment the in-flight request counter."""
        with self.lock:
            self.active_requests += 1

    def decrement_active(self):
        """Decrement the in-flight request counter, never going below zero."""
        with self.lock:
            self.active_requests = max(0, self.active_requests - 1)

    def _cleanup_old_requests(self, now: float):
        """Drop entries recorded more than ``window_seconds`` before ``now``.

        The caller must hold ``self.lock``. ``now`` must be a
        ``time.monotonic()`` reading so it is comparable with the stored
        timestamps.
        """
        cutoff = now - self.window_seconds
        while self.requests and self.requests[0][0] < cutoff:
            self.requests.popleft()

    def get_metrics(self) -> dict:
        """Return a snapshot of the metrics for the current window.

        Returns:
            A dict with the keys:

            * ``requests_per_sec`` - requests in the window divided by
              ``window_seconds``, truncated to an int.
            * ``avg_response_time_ms`` - mean ``duration_ms`` over the window.
            * ``error_rate`` - percentage (0-100) of requests whose status
              code is 400 or higher.
            * ``active_requests`` - current in-flight request count.

            When the window holds no requests the first three values are ``0``.
        """
        with self.lock:
            self._cleanup_old_requests(time.monotonic())
            active = self.active_requests
            total_requests = len(self.requests)
            if total_requests == 0:
                return {
                    "requests_per_sec": 0,
                    "avg_response_time_ms": 0,
                    "error_rate": 0,
                    "active_requests": active,
                }
            error_requests = sum(
                1 for _, status, _ in self.requests if status >= 400
            )
            total_duration = sum(duration for _, _, duration in self.requests)

        # total_requests is known to be non-zero here, so the divisions
        # below need no further guards.
        return {
            # NOTE: int() truncates, so fewer than window_seconds requests
            # in the window report 0. Kept as an int so existing consumers
            # of this key see the same type.
            "requests_per_sec": int(total_requests / self.window_seconds),
            "avg_response_time_ms": total_duration / total_requests,
            "error_rate": error_requests / total_requests * 100,
            "active_requests": active,
        }
# Module-level collector instance used by HTTPMetricsMiddleware; it
# aggregates request metrics over a 60-second window.
http_metrics_collector: HTTPMetricsCollector = HTTPMetricsCollector(
    window_seconds=60,
)
class HTTPMetricsMiddleware(BaseHTTPMiddleware):
    """Middleware that feeds request metrics into ``http_metrics_collector``.

    For every request outside the monitoring path prefix it tracks the
    in-flight request count and records the response status code together
    with the handling time in milliseconds.
    """

    async def dispatch(self, request: Request, call_next):
        """Handle one request and record its outcome.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request on and returns the
                response.

        Returns:
            The response produced by ``call_next``, unmodified.

        Raises:
            Exception: Any exception raised by ``call_next`` is re-raised
                after being recorded as a 500.
        """
        # Skip metrics endpoints to avoid recursive counting.
        if request.url.path.startswith("/api/v1/superadmin/monitoring"):
            return await call_next(request)

        http_metrics_collector.increment_active()
        # perf_counter() is monotonic and high-resolution, so the measured
        # duration cannot be distorted by system clock adjustments.
        started = time.perf_counter()
        try:
            response: Response = await call_next(request)
        except Exception:
            # The request failed before a response existed; count it as a
            # server error and let the exception propagate.
            elapsed_ms = (time.perf_counter() - started) * 1000
            http_metrics_collector.record_request(500, elapsed_ms)
            raise
        else:
            elapsed_ms = (time.perf_counter() - started) * 1000
            http_metrics_collector.record_request(
                response.status_code, elapsed_ms
            )
            return response
        finally:
            # Runs on every exit path so the active counter cannot leak.
            http_metrics_collector.decrement_active()
|