From e7659d2d0e32b14362eb62f66a416c6b884174e4 Mon Sep 17 00:00:00 2001 From: Shaishav Pidadi Date: Fri, 24 Apr 2026 19:30:14 -0400 Subject: [PATCH 1/2] feat(rate-limit): minute-bucket middleware with per-key and per-org counters (1.5c) Replaces the sliding-window-log limiter with minute-bucket sliding-window counters for four dimensions: per-key requests, per-key tokens, per-org requests, per-org tokens. Counters live under `{dim}:{scope}:{id}:{minute}` keys with a 2-minute TTL so the previous bucket contributes to the sliding-window weight. The limiter now runs as FastAPI middleware (`app.rate_limit_middleware`) before route handlers, so unauthenticated flood attempts cannot escape the counter by bailing in `require_api_key`. All responses carry `X-RateLimit-Limit`, `X-RateLimit-Remaining`, and `X-RateLimit-Reset` reflecting the most restrictive dimension; denied requests return 429 with `Retry-After`. Cipher review scope (precheck#31, non-blocking #4): - `REDIS_URL` posture validator: non-debug environments must use `rediss://` (TLS) and carry a password. Plaintext/passwordless Redis is debug-only. - Multi-replica quota-bypass: resolved by defaulting to `fail-closed` on Redis outage. The middleware returns 503 `rate limiter unavailable` rather than silently falling back to a per-replica in-memory counter that multiplies the effective quota by N. Operators can opt into `RATE_LIMIT_FAIL_MODE=open` (accept quota bypass) or `local` (debug-only per-replica fallback); `Settings` rejects `local` outside `DEBUG=true`. Behavior is documented in `precheck/PROJECT_SPECS.md#rate-limiting`. Refs: 62aac781-1312-4184-b5e3-39ce37afddb7 --- PROJECT_SPECS.md | 53 +++- app/api.py | 39 +-- app/main.py | 6 + app/rate_limit.py | 518 ++++++++++++++++++++++++---------- app/rate_limit_middleware.py | 151 ++++++++++ app/settings.py | 64 ++++- tests/test_rate_limit.py | 247 +++++++++++++--- tests/test_rate_limit_http.py | 117 ++++++-- tests/test_settings.py | 91 ++++++ 9 files changed, 1026 insertions(+), 260 deletions(-) create mode 100644 app/rate_limit_middleware.py diff --git a/PROJECT_SPECS.md b/PROJECT_SPECS.md index 8c3bb7d..1966084 100644 --- a/PROJECT_SPECS.md +++ b/PROJECT_SPECS.md @@ -948,9 +948,56 @@ Budget limits can be configured per user: - API key extracted from `X-Governs-Key` header and forwarded to webhook events ### Rate Limiting -- 100 requests per minute per user -- Configurable limits and windows -- Redis-based rate limiting (optional) + +Minute-bucket sliding-window counters, enforced by the +`app.rate_limit_middleware` FastAPI middleware before any route handler runs. +Four dimensions are evaluated per authenticated request: + +| Counter key | Default limit | +|------------------------------------|----------------------| +| `req:key:{key_hash}:{minute}` | 100 req/min | +| `tokens:key:{key_hash}:{minute}` | 100,000 tokens/min | +| `req:org:{org_id}:{minute}` | 1,000 req/min | +| `tokens:org:{org_id}:{minute}` | 1,000,000 tokens/min | + +Token cost is estimated from the request `Content-Length` as `ceil(bytes / 4)` +(standard rough heuristic) until §1.5d wires policy-driven limits and real +tokenizer counts. + +All responses carry `X-RateLimit-Limit`, `X-RateLimit-Remaining`, and +`X-RateLimit-Reset` reflecting the most restrictive dimension. Denied +requests return HTTP 429 with a `Retry-After` header in seconds. + +Unauthenticated paths (`/api/v1/health`, `/api/v1/ready`, `/api/metrics`, +`/docs`, `/redoc`, `/openapi.json`, `/`) skip the limiter so probes cannot +consume quota. 
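+
+For illustration (numbers hypothetical), using the weighting implemented in
+`app/rate_limit.py`: with the default per-key limit of 100 req/min, a
+previous bucket of 90 requests and 12 requests 15 s into the current bucket
+weigh in at `90 * (1 - 15/60) + 12 = 79.5`. The next request projects to
+80.5 and is admitted with `X-RateLimit-Remaining: 19`; with 120 requests in
+the previous bucket the projection is 103, so the request is denied with
+`Retry-After: 2`.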
+
+#### Redis posture
+
+`REDIS_URL` **must** use the `rediss://` TLS scheme and carry a password in
+any non-debug environment. The `Settings` validator rejects plaintext or
+passwordless URLs — this protects counters against on-path tampering and
+co-tenant reads. Plaintext `redis://` is accepted only when `DEBUG=true`.
+
+#### Redis outage behavior (`RATE_LIMIT_FAIL_MODE`)
+
+When Redis is configured but unreachable at request time the limiter
+evaluates `RATE_LIMIT_FAIL_MODE`:
+
+* `closed` — default. The middleware returns HTTP 503
+  `rate limiter unavailable`. Safe under multi-replica deployments.
+* `open` — requests are allowed without a counter check. Operators must
+  explicitly accept the quota-bypass risk.
+* `local` — per-replica in-memory fallback. Across N replicas this
+  multiplies the effective quota by N, so `Settings` rejects it outside
+  debug mode. Intended for single-replica dev.
+
+When `REDIS_URL` is unset entirely (dev/tests), the limiter runs purely
+against in-memory buckets regardless of `RATE_LIMIT_FAIL_MODE`.
+
+Rationale for the fail-closed default comes from Cipher's review on
+precheck#31: a silent per-replica in-memory fallback turns the configured
+quota into a per-replica allowance, letting N replicas admit N× the limit.
 
 ### PII Protection
 - Multiple redaction strategies
diff --git a/app/api.py b/app/api.py
index 4a8a551..3d8ef47 100644
--- a/app/api.py
+++ b/app/api.py
@@ -27,15 +27,12 @@
 )
 from .models import DecisionResponse, PrePostCheckRequest
 from .policies import evaluate, evaluate_with_payload_policy
-from .rate_limit import rate_limiter
 from .settings import settings
 from .storage import APIKey, get_db
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter()
-RATE_LIMIT_REQUESTS = 100
-RATE_LIMIT_WINDOW_SECONDS = 60
 
 
 def _ensure_correlation_id(corr_id: Optional[str]) -> str:
@@ -306,22 +303,8 @@ async def precheck(
     user_id = req.user_id
     correlation_id = _ensure_correlation_id(req.corr_id)
 
-    # Rate limiting (100 requests per minute per user/api_key)
-    if user_id:
-        rate_limit_key = f"precheck:{user_id}"
-    else:
-        rate_limit_key = f"precheck:key:{api_key}"
-    if not rate_limiter.is_allowed(
-        rate_limit_key, limit=RATE_LIMIT_REQUESTS, window=RATE_LIMIT_WINDOW_SECONDS
-    ):
-        retry_after = rate_limiter.retry_after(
-            rate_limit_key, limit=RATE_LIMIT_REQUESTS, window=RATE_LIMIT_WINDOW_SECONDS
-        )
-        raise HTTPException(
-            status_code=429,
-            detail="rate limit exceeded",
-            headers={"Retry-After": str(max(1, retry_after))},
-        )
+    # Rate limiting is enforced by app.rate_limit_middleware before this
+    # handler runs — see app/rate_limit_middleware.py.
 
     # Metrics: Track active requests
     set_active_requests("precheck", 1)
@@ -485,22 +468,8 @@ async def postcheck(
     user_id = req.user_id
     correlation_id = _ensure_correlation_id(req.corr_id)
 
-    # Rate limiting (100 requests per minute per user/api_key)
-    if user_id:
-        rate_limit_key = f"postcheck:{user_id}"
-    else:
-        rate_limit_key = f"postcheck:key:{api_key}"
-    if not rate_limiter.is_allowed(
-        rate_limit_key, limit=RATE_LIMIT_REQUESTS, window=RATE_LIMIT_WINDOW_SECONDS
-    ):
-        retry_after = rate_limiter.retry_after(
-            rate_limit_key, limit=RATE_LIMIT_REQUESTS, window=RATE_LIMIT_WINDOW_SECONDS
-        )
-        raise HTTPException(
-            status_code=429,
-            detail="rate limit exceeded",
-            headers={"Retry-After": str(max(1, retry_after))},
-        )
+    # Rate limiting is enforced by app.rate_limit_middleware before this
+    # handler runs — see app/rate_limit_middleware.py.
 
# Metrics: Track active requests set_active_requests("postcheck", 1) diff --git a/app/main.py b/app/main.py index 9b1e408..a58a8ad 100644 --- a/app/main.py +++ b/app/main.py @@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse from .api import router +from .rate_limit_middleware import install_rate_limit_middleware from .settings import settings from .storage import create_tables @@ -50,6 +51,11 @@ def create_app() -> FastAPI: lifespan=lifespan, ) + # Middleware registration order is inside-out: the LAST decorator runs + # OUTERMOST. Install rate limiting first so request_id and response_time + # still apply to 429 / 503 responses. + install_rate_limit_middleware(app) + @app.middleware("http") async def request_id_middleware(request: Request, call_next): request_id = str(uuid.uuid4()) diff --git a/app/rate_limit.py b/app/rate_limit.py index 0127648..633263e 100644 --- a/app/rate_limit.py +++ b/app/rate_limit.py @@ -1,9 +1,32 @@ +"""Minute-bucket sliding-window rate limiter. + +Counter shape (per §1.5c): + - ``req:key:{key_id}:{minute_bucket}`` + - ``tokens:key:{key_id}:{minute_bucket}`` + - ``req:org:{org_id}:{minute_bucket}`` + - ``tokens:org:{org_id}:{minute_bucket}`` + +Each minute bucket is an atomic Redis counter with a two-minute TTL so the +previous bucket is still visible for the sliding-window weight. + +Sliding-window weight (Cloudflare-style): + + weighted = prev_count * (1 - elapsed_in_current / 60) + current_count + +Request is denied when ``weighted + cost > limit`` for any dimension. + +Redis-outage behavior is controlled by ``fail_mode`` (Cipher review on +precheck#31). See ``RateLimiter.__init__`` for semantics. +""" + +from __future__ import annotations + import logging import math import threading import time -from collections import deque -from typing import Deque, Dict, Optional +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Sequence, Tuple from .settings import settings @@ -11,185 +34,382 @@ try: import redis -except Exception: # pragma: no cover - exercised in environments without redis package +except Exception: # pragma: no cover redis = None -class RateLimiter: - """Redis-first sliding-window rate limiter with in-memory fallback.""" +WINDOW_SECONDS = 60 +_BUCKET_TTL_SECONDS = WINDOW_SECONDS * 2 + + +@dataclass(frozen=True) +class LimitSpec: + """One rate-limit dimension to evaluate on this request.""" + + name: str # e.g. "req-key", "req-org", "tokens-key", "tokens-org" + key: str # Redis key prefix, e.g. "req:key:" (bucket appended at runtime) + limit: int + cost: int # 1 for request counters, token count for token counters + + +@dataclass(frozen=True) +class LimitState: + limit: int + remaining: int + reset_in: int # seconds until the current minute bucket ends + retry_after: int # seconds until a request of this cost would be permitted - def __init__(self, redis_url: Optional[str] = None): + +class LimiterUnavailableError(RuntimeError): + """Raised (internally) when Redis is configured but unreachable and the + operator has not opted into a fallback mode. Middleware translates this + into a 503 response.""" + + +@dataclass(frozen=True) +class RateLimitResult: + allowed: bool + states: Dict[str, LimitState] + # Populated when the operator's fail-mode requires a 503 instead of 429. + fail_closed_reason: Optional[str] = None + + +class RateLimiter: + """Redis-backed minute-bucket rate limiter with explicit outage behavior. 
+
+    ``fail_mode``:
+    * ``"closed"`` — Redis configured but unreachable returns a
+      ``fail_closed_reason`` so the middleware can reply ``503``. This is
+      the safe default under multi-replica deployments where a per-replica
+      local fallback would multiply the effective quota by N replicas.
+    * ``"open"`` — Redis unreachable → allow the request with no counter
+      check. Opt-in only.
+    * ``"local"`` — Redis unreachable → fall back to a per-replica in-memory
+      counter. Intended for single-replica dev setups; rejected by
+      ``Settings`` outside debug mode.
+
+    When ``REDIS_URL`` was never configured (development/tests), the limiter
+    always runs against local in-memory buckets regardless of ``fail_mode``.
+    """
+
+    def __init__(
+        self,
+        redis_url: Optional[str] = None,
+        fail_mode: str = "closed",
+    ):
+        if fail_mode not in {"closed", "open", "local"}:
+            raise ValueError(f"invalid rate_limit_fail_mode: {fail_mode!r}")
+        self.fail_mode = fail_mode
+        self._redis_url_configured = bool(redis_url)
         self.redis_client = None
         self._local_lock = threading.Lock()
-        self._local_windows: Dict[str, Deque[float]] = {}
-        self._local_last_seen: Dict[str, float] = {}
-        self._local_idle_ttl = 3600.0
-        self._cleanup_interval = 60.0
-        self._last_cleanup = 0.0
+        # Map of "{key}:{bucket}" -> count for the in-memory fallback.
+        self._local_buckets: Dict[str, int] = {}
 
         if redis_url and redis is not None:
             try:
                 self.redis_client = redis.from_url(redis_url)
-                # Test connection
                 self.redis_client.ping()
-            except Exception as e:
-                logger.warning("Failed to connect to Redis: %s", type(e).__name__)
+            except Exception as exc:
+                logger.warning(
+                    "Failed to connect to Redis for rate limiter: %s",
+                    type(exc).__name__,
+                )
                 self.redis_client = None
         elif redis_url and redis is None:
-            logger.warning("redis package not installed; using in-memory rate limiter")
-
-    def is_allowed(self, key: str, limit: int, window: int) -> bool:
-        """
-        Check if request is allowed using a sliding window counter.
- - Args: - key: Unique identifier for the rate limit (e.g., user_id) - limit: Maximum number of requests allowed - window: Time window in seconds + logger.warning( + "redis package not installed; rate limiter degraded to in-memory mode" + ) - Returns: - True if request is allowed, False otherwise - """ - if limit <= 0 or window <= 0: - return False + # ---------------------------------------------------------------- public - if self.redis_client: - try: - return self._is_allowed_redis(key=key, limit=limit, window=window) - except Exception as e: - logger.warning( - "Redis rate limiter unavailable; falling back to in-memory limiter: %s", - type(e).__name__, - ) + def check(self, specs: Sequence[LimitSpec]) -> RateLimitResult: + """Evaluate all dimensions and increment counters on allow.""" + if not specs: + return RateLimitResult(allowed=True, states={}) - return self._is_allowed_local(key=key, limit=limit, window=window) + now = time.time() + bucket = int(now // WINDOW_SECONDS) + elapsed_in_current = now - bucket * WINDOW_SECONDS - def retry_after(self, key: str, limit: int, window: int) -> int: - """Return seconds until the next request should be allowed.""" - if limit <= 0: - return max(1, int(math.ceil(window))) - if window <= 0: - return 1 + if self._use_local(): + return self._check_local(specs, bucket, elapsed_in_current) - if self.redis_client: - try: - return self._retry_after_redis(key=key, limit=limit, window=window) - except Exception as e: - logger.warning( - "Redis rate limiter unavailable; falling back to in-memory retry-after: %s", - type(e).__name__, - ) + if self.redis_client is None: + return self._handle_unavailable(specs, reason="redis-not-connected") - return self._retry_after_local(key=key, limit=limit, window=window) + try: + return self._check_redis(specs, bucket, elapsed_in_current) + except Exception as exc: + logger.warning( + "Redis rate limiter request failed: %s", type(exc).__name__ + ) + return self._handle_unavailable( + specs, reason=f"redis-error:{type(exc).__name__}" + ) def clear(self) -> None: - """Clear in-memory fallback state.""" + """Clear in-memory state. Intended for tests only.""" with self._local_lock: - self._local_windows.clear() - self._local_last_seen.clear() - self._last_cleanup = 0.0 - - def _is_allowed_redis(self, key: str, limit: int, window: int) -> bool: - current_time = time.time() - window_start = current_time - window - member = f"{current_time}:{time.time_ns()}" - - # Use Redis pipeline for atomic operations. 
- pipe = self.redis_client.pipeline() - pipe.zremrangebyscore(key, 0, window_start) - pipe.zcard(key) - pipe.zadd(key, {member: current_time}) - pipe.expire(key, max(1, int(window))) - - results = pipe.execute() - current_count = int(results[1]) - return current_count < limit - - def _retry_after_redis(self, key: str, limit: int, window: int) -> int: - current_time = time.time() - window_start = current_time - window - - pipe = self.redis_client.pipeline() - pipe.zremrangebyscore(key, 0, window_start) - pipe.zcard(key) - results = pipe.execute() - - current_count = int(results[1]) - if current_count < limit: - return 0 - - next_allowed_index = current_count - limit - next_allowed = self.redis_client.zrange( - key, - next_allowed_index, - next_allowed_index, - withscores=True, + self._local_buckets.clear() + + # ------------------------------------------------------- internal helpers + + def _use_local(self) -> bool: + """Run purely in local mode when no Redis URL was ever configured.""" + return not self._redis_url_configured + + def _handle_unavailable( + self, specs: Sequence[LimitSpec], reason: str + ) -> RateLimitResult: + if self.fail_mode == "open": + # Operator opted into quota-bypass on Redis outage. + return RateLimitResult( + allowed=True, + states={s.name: self._unknown_state(s) for s in specs}, + ) + if self.fail_mode == "local": + now = time.time() + bucket = int(now // WINDOW_SECONDS) + elapsed_in_current = now - bucket * WINDOW_SECONDS + return self._check_local(specs, bucket, elapsed_in_current) + # fail_mode == "closed" — caller translates this to HTTP 503. + return RateLimitResult( + allowed=False, + states={s.name: self._unknown_state(s) for s in specs}, + fail_closed_reason=reason, ) - if not next_allowed: - return 0 - - next_allowed_at = float(next_allowed[0][1]) + window - return max(1, int(math.ceil(next_allowed_at - current_time))) - - def _is_allowed_local(self, key: str, limit: int, window: int) -> bool: - current_time = time.time() - window_start = current_time - window - - with self._local_lock: - self._cleanup_local_state(current_time) - events = self._local_windows.setdefault(key, deque()) - - while events and events[0] <= window_start: - events.popleft() - - self._local_last_seen[key] = current_time - - if len(events) >= limit: - return False - - events.append(current_time) - return True - def _retry_after_local(self, key: str, limit: int, window: int) -> int: - current_time = time.time() - window_start = current_time - window + @staticmethod + def _unknown_state(spec: LimitSpec) -> LimitState: + return LimitState( + limit=spec.limit, remaining=spec.limit, reset_in=WINDOW_SECONDS, retry_after=0 + ) - with self._local_lock: - self._cleanup_local_state(current_time) - events = self._local_windows.get(key) - if not events: + # ---------------------------------------------------------------- Redis + + def _check_redis( + self, + specs: Sequence[LimitSpec], + bucket: int, + elapsed_in_current: float, + ) -> RateLimitResult: + client = self.redis_client + assert client is not None # guarded by caller + + current_keys = [f"{s.key}:{bucket}" for s in specs] + previous_keys = [f"{s.key}:{bucket - 1}" for s in specs] + + pipe = client.pipeline() + for k in current_keys: + pipe.get(k) + for k in previous_keys: + pipe.get(k) + raw = pipe.execute() + + current_counts = [self._parse(v) for v in raw[: len(specs)]] + previous_counts = [self._parse(v) for v in raw[len(specs) :]] + + allowed = True + states: Dict[str, LimitState] = {} + for spec, curr, prev in zip(specs, 
current_counts, previous_counts): + weighted_before = _weighted(prev, curr, elapsed_in_current) + # Would this request fit under the limit? + projected = weighted_before + spec.cost + state = _state_for(spec, curr, prev, elapsed_in_current, projected) + states[spec.name] = state + if projected > spec.limit: + allowed = False + + if allowed: + # Atomically increment and refresh TTL on the current bucket only. + pipe = client.pipeline() + for key, spec in zip(current_keys, specs): + pipe.incrby(key, spec.cost) + pipe.expire(key, _BUCKET_TTL_SECONDS) + pipe.execute() + + return RateLimitResult(allowed=allowed, states=states) + + @staticmethod + def _parse(raw) -> int: + if raw is None: + return 0 + if isinstance(raw, (bytes, bytearray)): + try: + return int(raw) + except ValueError: return 0 + if isinstance(raw, int): + return raw + try: + return int(raw) + except (TypeError, ValueError): + return 0 - while events and events[0] <= window_start: - events.popleft() - - if not events: - self._local_windows.pop(key, None) - self._local_last_seen.pop(key, None) - return 0 + # ------------------------------------------------------------- in-memory - self._local_last_seen[key] = current_time - if len(events) < limit: - return 0 + def _check_local( + self, + specs: Sequence[LimitSpec], + bucket: int, + elapsed_in_current: float, + ) -> RateLimitResult: + with self._local_lock: + self._gc_local(bucket) + allowed = True + observations: List[Tuple[LimitSpec, int, int]] = [] + for spec in specs: + curr = self._local_buckets.get(f"{spec.key}:{bucket}", 0) + prev = self._local_buckets.get(f"{spec.key}:{bucket - 1}", 0) + observations.append((spec, curr, prev)) + weighted_before = _weighted(prev, curr, elapsed_in_current) + if weighted_before + spec.cost > spec.limit: + allowed = False + + states: Dict[str, LimitState] = {} + for spec, curr, prev in observations: + weighted_before = _weighted(prev, curr, elapsed_in_current) + projected = weighted_before + spec.cost + states[spec.name] = _state_for( + spec, curr, prev, elapsed_in_current, projected + ) - next_allowed_at = events[len(events) - limit] + window - return max(1, int(math.ceil(next_allowed_at - current_time))) + if allowed: + for spec, _curr, _prev in observations: + k = f"{spec.key}:{bucket}" + self._local_buckets[k] = self._local_buckets.get(k, 0) + spec.cost - def _cleanup_local_state(self, current_time: float) -> None: - if current_time - self._last_cleanup < self._cleanup_interval: - return + return RateLimitResult(allowed=allowed, states=states) - expired_keys = [ - key - for key, last_seen in self._local_last_seen.items() - if current_time - last_seen > self._local_idle_ttl + def _gc_local(self, bucket: int) -> None: + """Drop buckets older than the previous one.""" + stale = [ + k + for k in self._local_buckets + if int(k.rsplit(":", 1)[1]) < bucket - 1 ] - for expired_key in expired_keys: - self._local_last_seen.pop(expired_key, None) - self._local_windows.pop(expired_key, None) + for k in stale: + self._local_buckets.pop(k, None) + + +# ---------------------------------------------------------------- helpers + + +def _weighted(prev: int, current: int, elapsed_in_current: float) -> float: + """Sliding-window count over the current 60-second window.""" + if elapsed_in_current >= WINDOW_SECONDS: + return float(current) + ratio = 1.0 - (elapsed_in_current / WINDOW_SECONDS) + return prev * ratio + current + + +def _state_for( + spec: LimitSpec, + current: int, + previous: int, + elapsed_in_current: float, + projected: float, +) -> LimitState: 
+ """Compute the LimitState returned to callers for this dimension. + + ``remaining`` is reported against the sliding window *after* admitting + this request. When the request would be denied, ``retry_after`` is the + number of seconds until the oldest contributing request ages out enough + for ``projected <= limit`` to hold. + """ + reset_in = max(1, int(math.ceil(WINDOW_SECONDS - elapsed_in_current))) + remaining = max(0, int(math.floor(spec.limit - projected))) + + if projected <= spec.limit: + return LimitState( + limit=spec.limit, + remaining=remaining, + reset_in=reset_in, + retry_after=0, + ) + + # Denied: figure out when the sliding weight drops enough to admit + # ``spec.cost`` again. + # + # weighted(t) = previous * (1 - (elapsed + t) / 60) + current + cost + # solve weighted(t) <= limit for t: + # + if previous > 0: + # t such that previous * (1 - (elapsed + t)/60) + current + cost <= limit + # => previous * (elapsed + t) / 60 >= previous + current + cost - limit + # => t >= 60 * (previous + current + cost - limit) / previous - elapsed + required = (previous + current + spec.cost - spec.limit) * WINDOW_SECONDS + t = required / previous - elapsed_in_current + retry_after = max(1, int(math.ceil(t))) + # Capped at reset_in: after the current bucket ends the previous one + # is gone entirely. + retry_after = min(retry_after, reset_in) + else: + # previous is zero → only the current bucket contributes; we must + # wait for it to roll. + retry_after = reset_in + + return LimitState( + limit=spec.limit, + remaining=0, + reset_in=reset_in, + retry_after=retry_after, + ) + + +# ---------------------------------------------------------------- default specs + + +def specs_for_request( + key_id: str, + org_id: Optional[str], + token_cost: int, +) -> List[LimitSpec]: + """Build the standard four-dimension spec list for a single request. + + ``key_id`` and ``org_id`` are opaque identifiers (typically HMAC hashes of + the raw API key, and the org UUID). Token cost should be a positive int; + the caller is responsible for the estimation policy. + """ + token_cost = max(1, int(token_cost)) + out: List[LimitSpec] = [ + LimitSpec( + name="req-key", + key=f"req:key:{key_id}", + limit=settings.rate_limit_requests_per_minute, + cost=1, + ), + LimitSpec( + name="tokens-key", + key=f"tokens:key:{key_id}", + limit=settings.rate_limit_tokens_per_minute, + cost=token_cost, + ), + ] + if org_id: + out.extend( + [ + LimitSpec( + name="req-org", + key=f"req:org:{org_id}", + limit=settings.rate_limit_org_requests_per_minute, + cost=1, + ), + LimitSpec( + name="tokens-org", + key=f"tokens:org:{org_id}", + limit=settings.rate_limit_org_tokens_per_minute, + cost=token_cost, + ), + ] + ) + return out - self._last_cleanup = current_time +# ---------------------------------------------------------------- singleton -# Global rate limiter instance -rate_limiter = RateLimiter(settings.redis_url) +rate_limiter = RateLimiter( + redis_url=settings.redis_url, + fail_mode=settings.rate_limit_fail_mode, +) diff --git a/app/rate_limit_middleware.py b/app/rate_limit_middleware.py new file mode 100644 index 0000000..d4d30ca --- /dev/null +++ b/app/rate_limit_middleware.py @@ -0,0 +1,151 @@ +"""FastAPI middleware: per-key + per-org minute-bucket rate limiting. + +Evaluated before route handlers so a flood of invalid-but-well-formed +requests from a single key cannot bypass the limiter by bailing in +``require_api_key``. 
Unauthenticated paths (``/api/v1/health``,
+``/api/v1/ready``, ``/api/metrics``, ``/docs``, ``/redoc``,
+``/openapi.json``, ``/``) are allowed through without counter interaction
+so readiness probes and the metrics scrape cannot be rate-limited or
+consume quota.
+
+On ``fail_closed_reason`` (Redis configured but unreachable under the
+``closed`` fail-mode), the middleware replies with HTTP 503.
+"""
+
+from __future__ import annotations
+
+import logging
+import math
+import time
+from typing import Awaitable, Callable, Optional
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response
+from sqlalchemy.exc import SQLAlchemyError
+
+from .key_utils import hash_api_key
+from .rate_limit import RateLimitResult, rate_limiter, specs_for_request
+from .settings import settings
+from .storage import APIKey, SessionLocal
+
+logger = logging.getLogger(__name__)
+
+
+_UNAUTH_PATHS = frozenset(
+    {
+        "/",
+        "/api/v1/health",
+        "/api/v1/ready",
+        "/api/metrics",
+        "/docs",
+        "/redoc",
+        "/openapi.json",
+    }
+)
+
+
+def _tokens_estimate(request: Request) -> int:
+    """Rough token estimate from Content-Length.
+
+    Real LLM token counts require a tokenizer and the body. For middleware-
+    level enforcement we approximate ``ceil(bytes / 4)`` — standard rough
+    heuristic for English text — so per-request budget changes show up on the
+    token counter before the request reaches the model. Post-response
+    reconciliation (§1.5d) can refine this later.
+    """
+    raw = request.headers.get("content-length")
+    if not raw:
+        return 1
+    try:
+        n = int(raw)
+    except ValueError:
+        return 1
+    return max(1, math.ceil(n / 4))
+
+
+def _lookup_org_id(raw_key: str) -> Optional[str]:
+    """Look up the ``org_id`` for ``raw_key``. Returns None if the key is
+    unknown — authentication will reject the request downstream."""
+    try:
+        key_hash = hash_api_key(raw_key)
+    except Exception:  # pragma: no cover - defensive
+        return None
+    session = SessionLocal()
+    try:
+        record = session.query(APIKey).filter(APIKey.key_hash == key_hash).first()
+        if record is None:
+            return None
+        return record.org_id
+    except SQLAlchemyError as exc:
+        logger.warning("Rate-limit org lookup failed: %s", type(exc).__name__)
+        return None
+    finally:
+        session.close()
+
+
+def _apply_headers(response: Response, result: RateLimitResult) -> None:
+    if not result.states:
+        return
+    # Report the most restrictive dimension so clients see the real budget.
+    tightest = min(result.states.values(), key=lambda s: s.remaining)
+    response.headers["X-RateLimit-Limit"] = str(tightest.limit)
+    response.headers["X-RateLimit-Remaining"] = str(tightest.remaining)
+    response.headers["X-RateLimit-Reset"] = str(int(time.time()) + tightest.reset_in)
+
+
+def _retry_after_for(states) -> int:
+    """Seconds until the most restrictive denied dimension admits again."""
+    denied = [s.retry_after for s in states.values() if s.retry_after > 0]
+    if not denied:
+        return 1
+    return max(1, max(denied))
+
+
+def install_rate_limit_middleware(app: FastAPI) -> None:
+    """Register the rate-limit middleware on ``app``."""
+
+    @app.middleware("http")
+    async def rate_limit_middleware(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        if request.url.path in _UNAUTH_PATHS:
+            return await call_next(request)
+
+        raw_key = request.headers.get(settings.api_key_header.lower())
+        if not raw_key:
+            # require_api_key will 401. Don't consume quota on missing auth.
+ return await call_next(request) + + key_hash = hash_api_key(raw_key) + org_id = _lookup_org_id(raw_key) + token_cost = _tokens_estimate(request) + specs = specs_for_request( + key_id=key_hash, org_id=org_id, token_cost=token_cost + ) + + result = rate_limiter.check(specs) + + if result.fail_closed_reason is not None: + logger.warning( + "rate limiter unavailable (fail-closed): %s", + result.fail_closed_reason, + ) + resp = JSONResponse( + status_code=503, + content={"detail": "rate limiter unavailable"}, + ) + resp.headers["Retry-After"] = "1" + return resp + + if not result.allowed: + retry_after = _retry_after_for(result.states) + resp = JSONResponse( + status_code=429, + content={"detail": "rate limit exceeded"}, + headers={"Retry-After": str(retry_after)}, + ) + _apply_headers(resp, result) + return resp + + response = await call_next(request) + _apply_headers(response, result) + return response diff --git a/app/settings.py b/app/settings.py index 2f0bfb8..73131a8 100644 --- a/app/settings.py +++ b/app/settings.py @@ -1,4 +1,5 @@ from typing import Optional +from urllib.parse import urlsplit from pydantic import AliasChoices, Field, model_validator from pydantic_settings import BaseSettings @@ -8,6 +9,15 @@ _DEFAULT_KEY_HMAC_SECRET = "dev-key-hmac-secret-change-in-production" _MIN_SECRET_LENGTH = 32 +# Allowed values for RATE_LIMIT_FAIL_MODE. +# - "closed": deny (HTTP 503) when Redis is configured but unreachable. Safe +# default in multi-replica deployments — a per-replica local fallback would +# multiply the effective quota by N replicas (Cipher review on precheck#31). +# - "open": allow without a counter check. Operator must explicitly accept +# the quota-bypass risk. +# - "local": per-replica in-memory fallback. Intended for single-replica dev. +_RATE_LIMIT_FAIL_MODES = {"closed", "open", "local"} + class Settings(BaseSettings): """Application settings loaded from environment variables""" @@ -22,10 +32,26 @@ class Settings(BaseSettings): validation_alias=AliasChoices("DB_URL", "DATABASE_URL"), ) - # Redis configuration (optional) + # Redis configuration (optional). + # In non-debug environments REDIS_URL must use the TLS scheme (rediss://) + # and carry a password; see _validate_redis_url_posture below. redis_url: Optional[str] = None precheck_allow_cache_ttl_seconds: int = 60 + # Rate limiter behavior on Redis outage. See _RATE_LIMIT_FAIL_MODES. + # Default is "closed" (fail-closed 503) to avoid the per-replica quota- + # bypass described in the Cipher review on precheck#31. Operators running + # a single replica in development may set this to "local". + rate_limit_fail_mode: str = "closed" + + # Default per-minute limits. These are baselines used by the rate-limit + # middleware when no policy override is supplied. Policy-driven overrides + # land in §1.5d. + rate_limit_requests_per_minute: int = 100 + rate_limit_tokens_per_minute: int = 100_000 + rate_limit_org_requests_per_minute: int = 1_000 + rate_limit_org_tokens_per_minute: int = 1_000_000 + # Public base URL for cloud mode public_base: Optional[str] = None @@ -64,6 +90,11 @@ class Settings(BaseSettings): @model_validator(mode="after") def _reject_default_secrets(self) -> "Settings": + if self.rate_limit_fail_mode not in _RATE_LIMIT_FAIL_MODES: + raise ValueError( + f"RATE_LIMIT_FAIL_MODE must be one of {sorted(_RATE_LIMIT_FAIL_MODES)}; " + f"got {self.rate_limit_fail_mode!r}." 
+            )
         if not self.debug:
             self._validate_secret(
                 name="PII_TOKEN_SALT",
@@ -80,8 +111,39 @@ def _reject_default_secrets(self) -> "Settings":
                 value=self.key_hmac_secret,
                 default_marker=_DEFAULT_KEY_HMAC_SECRET,
             )
+            self._validate_redis_url_posture()
+            if self.rate_limit_fail_mode == "local":
+                raise ValueError(
+                    "RATE_LIMIT_FAIL_MODE=local is only permitted in debug mode; "
+                    "across multiple replicas the per-replica in-memory counter "
+                    "multiplies the effective quota by N. Use 'closed' (default) "
+                    "or explicitly opt into 'open'."
+                )
         return self
 
+    def _validate_redis_url_posture(self) -> None:
+        """Reject plaintext or passwordless REDIS_URL outside debug mode.
+
+        Rate-limit counters, the allow-decision cache, and any future queue
+        traffic flow through this URL. Plaintext redis:// exposes API-key
+        fingerprints and quota state on the wire; an unauthenticated Redis
+        allows any pod in the namespace to read or poison the same counters.
+        Both are rejected in non-debug environments.
+        """
+        if not self.redis_url:
+            return
+        parsed = urlsplit(self.redis_url)
+        if parsed.scheme != "rediss":
+            raise ValueError(
+                "REDIS_URL must use the rediss:// (TLS) scheme outside debug mode; "
+                f"got scheme {parsed.scheme!r}."
+            )
+        if not parsed.password:
+            raise ValueError(
+                "REDIS_URL must include a password outside debug mode; "
+                "unauthenticated Redis lets any co-tenant read or poison rate-limit counters."
+            )
+
     @staticmethod
     def _validate_secret(name: str, value: str, default_marker: str) -> None:
         if not value:
diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py
index 29dfeb0..bb69589 100644
--- a/tests/test_rate_limit.py
+++ b/tests/test_rate_limit.py
@@ -1,77 +1,240 @@
-from app.rate_limit import RateLimiter
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2024 GovernsAI. All rights reserved.
+"""Unit tests for the minute-bucket sliding-window rate limiter.
 
+Covers the §1.5c requirements directly on ``RateLimiter`` — the HTTP-level
+behavior (429, X-RateLimit-* headers) is exercised in test_rate_limit_http.py.
+"""
 
-class FailingPipeline:
-    def zremrangebyscore(self, *_args, **_kwargs):
-        return self
+import pytest
+
+from app.rate_limit import (
+    LimitSpec,
+    RateLimiter,
+    WINDOW_SECONDS,
+    specs_for_request,
+)
+
+
+def _specs(*, limit: int, cost: int = 1, key: str = "req:key:k1", name: str = "req-key"):
+    return [LimitSpec(name=name, key=key, limit=limit, cost=cost)]
+
+
+# ---------------------------------------------------------------- bucketing
+
+
+def test_counter_increments_per_request(monkeypatch):
+    limiter = RateLimiter(redis_url=None)
+    monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0)
+
+    r1 = limiter.check(_specs(limit=3))
+    r2 = limiter.check(_specs(limit=3))
+    r3 = limiter.check(_specs(limit=3))
+    r4 = limiter.check(_specs(limit=3))
+
+    assert [r.allowed for r in (r1, r2, r3, r4)] == [True, True, True, False]
+    assert [r.states["req-key"].remaining for r in (r1, r2, r3)] == [2, 1, 0]
+
+
+def test_counter_resets_after_minute_window(monkeypatch):
+    limiter = RateLimiter(redis_url=None)
+    now = [1000.0]
+    monkeypatch.setattr("app.rate_limit.time.time", lambda: now[0])
+
+    assert limiter.check(_specs(limit=1)).allowed is True
+    assert limiter.check(_specs(limit=1)).allowed is False
+
+    # Advance two full windows so the filled bucket is older than the
+    # previous bucket and its sliding-window contribution is gone; one
+    # window later it would still carry partial weight and deny.
+    now[0] = 1000.0 + WINDOW_SECONDS * 2  # skip prev-bucket entirely
+
+    assert limiter.check(_specs(limit=1)).allowed is True
 
-
+
+def test_partial_window_applies_sliding_weight(monkeypatch):
+    """A full previous bucket keeps two-thirds of its weight 20s into the next.
+
+    At t=960 bucket 16 starts (elapsed=0); fill it with 50 requests.
+    At t=1040 bucket=17, elapsed=20; previous weight = 50 * (1 - 20/60) ≈ 33.33.
+    Only 16 more requests fit under the limit of 50; the 17th denies.
+    """
+    limiter = RateLimiter(redis_url=None)
+    now = [960.0]  # bucket 16 start; elapsed_in_current=0
+    monkeypatch.setattr("app.rate_limit.time.time", lambda: now[0])
+
+    for _ in range(50):
+        assert limiter.check(_specs(limit=50)).allowed is True
+
+    # bucket 17, 20s in — previous contribution ≈ 50 * (40/60) = 33.33
+    now[0] = 1040.0
+    allowed = 0
+    for _ in range(200):
+        if limiter.check(_specs(limit=50)).allowed:
+            allowed += 1
+        else:
+            break
+    # 50 - 33.33 = 16.67 → floor allows 16 more this bucket.
+    assert allowed == 16
+
+
+# ---------------------------------------------------------------- dimensions
+
+
+def test_per_key_and_per_org_counters_are_independent(monkeypatch):
+    """Same org, different API keys — per-key is cheap, per-org shared."""
+    limiter = RateLimiter(redis_url=None)
+    monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0)
+
+    def specs(key_id: str, org_id: str, req_limit: int, org_limit: int):
+        return [
+            LimitSpec(
+                name="req-key", key=f"req:key:{key_id}", limit=req_limit, cost=1
+            ),
+            LimitSpec(
+                name="req-org", key=f"req:org:{org_id}", limit=org_limit, cost=1
+            ),
+        ]
+
+    # Key A exhausts its per-key limit (2) but stays under the per-org limit (10).
+    assert limiter.check(specs("A", "org1", 2, 10)).allowed is True
+    assert limiter.check(specs("A", "org1", 2, 10)).allowed is True
+    denied = limiter.check(specs("A", "org1", 2, 10))
+    assert denied.allowed is False
+    assert denied.states["req-key"].remaining == 0
+    # Per-org dim is not the blocker — the blocker is per-key.
+    assert denied.states["req-org"].remaining > 0
+
+    # Key B in the same org can still proceed on its own per-key counter.
+    assert limiter.check(specs("B", "org1", 2, 10)).allowed is True
+
+
+def test_org_limit_denies_even_when_per_key_allows(monkeypatch):
+    limiter = RateLimiter(redis_url=None)
+    monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0)
+
+    # Saturate the per-org counter via key A (per-key limit is generous).
+    org_specs = lambda key_id: [  # noqa: E731
+        LimitSpec(name="req-key", key=f"req:key:{key_id}", limit=10, cost=1),
+        LimitSpec(name="req-org", key="req:org:org1", limit=2, cost=1),
+    ]
+    assert limiter.check(org_specs("A")).allowed is True
+    assert limiter.check(org_specs("A")).allowed is True
+
+    # Key B is fresh on per-key but blocked by shared per-org counter.
+    result = limiter.check(org_specs("B"))
+    assert result.allowed is False
+    assert result.states["req-org"].remaining == 0
+    assert result.states["req-key"].remaining > 0
+
+
+def test_token_cost_applied_to_token_counters(monkeypatch):
+    limiter = RateLimiter(redis_url=None)
+    monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0)
+
+    specs = lambda cost: [  # noqa: E731
+        LimitSpec(name="tokens-key", key="tokens:key:x", limit=100, cost=cost),
+    ]
+    assert limiter.check(specs(60)).allowed is True
+    # Second request at cost=60 would push to 120, over the 100 limit.
+ result = limiter.check(specs(60)) + assert result.allowed is False + + +# ------------------------------------------------------ specs_for_request + + +def test_specs_for_request_omits_org_when_none(): + specs = specs_for_request(key_id="kh", org_id=None, token_cost=10) + names = {s.name for s in specs} + assert names == {"req-key", "tokens-key"} + + +def test_specs_for_request_includes_org_when_provided(): + specs = specs_for_request(key_id="kh", org_id="org1", token_cost=10) + names = {s.name for s in specs} + assert names == {"req-key", "tokens-key", "req-org", "tokens-org"} + + +# ----------------------------------------------------------- fail modes + + +class _FailingPipeline: + def get(self, *_a, **_kw): return self - def zadd(self, *_args, **_kwargs): + def incrby(self, *_a, **_kw): return self - def expire(self, *_args, **_kwargs): + def expire(self, *_a, **_kw): return self def execute(self): raise RuntimeError("redis unavailable") -class FailingRedis: +class _FailingRedis: def pipeline(self): - return FailingPipeline() + return _FailingPipeline() -def test_in_memory_fallback_enforces_limit_without_redis(): - limiter = RateLimiter(redis_url=None) +def _install_failing_redis(limiter: RateLimiter) -> None: + limiter._redis_url_configured = True + limiter.redis_client = _FailingRedis() - assert limiter.is_allowed("user-a", limit=2, window=60) is True - assert limiter.is_allowed("user-a", limit=2, window=60) is True - assert limiter.is_allowed("user-a", limit=2, window=60) is False +def test_fail_closed_on_redis_outage_returns_fail_closed_reason(monkeypatch): + limiter = RateLimiter(redis_url=None, fail_mode="closed") + _install_failing_redis(limiter) + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) -def test_in_memory_fallback_enforces_limit_when_redis_errors(): - limiter = RateLimiter(redis_url=None) - limiter.redis_client = FailingRedis() + result = limiter.check(_specs(limit=10)) + assert result.allowed is False + assert result.fail_closed_reason is not None + assert "redis-error" in result.fail_closed_reason - assert limiter.is_allowed("user-b", limit=1, window=60) is True - assert limiter.is_allowed("user-b", limit=1, window=60) is False +def test_fail_open_on_redis_outage_allows_request(monkeypatch): + limiter = RateLimiter(redis_url=None, fail_mode="open") + _install_failing_redis(limiter) + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) -def test_in_memory_fallback_resets_after_window(monkeypatch): - limiter = RateLimiter(redis_url=None) - now = [1000.0] + result = limiter.check(_specs(limit=1)) + assert result.allowed is True + assert result.fail_closed_reason is None - monkeypatch.setattr("app.rate_limit.time.time", lambda: now[0]) - assert limiter.is_allowed("user-c", limit=1, window=10) is True - assert limiter.is_allowed("user-c", limit=1, window=10) is False +def test_fail_local_on_redis_outage_uses_in_memory(monkeypatch): + limiter = RateLimiter(redis_url=None, fail_mode="local") + _install_failing_redis(limiter) + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) - now[0] = 1011.0 - assert limiter.is_allowed("user-c", limit=1, window=10) is True + assert limiter.check(_specs(limit=1)).allowed is True + assert limiter.check(_specs(limit=1)).allowed is False -def test_clear_resets_in_memory_fallback_state(): - limiter = RateLimiter(redis_url=None) +def test_no_redis_url_configured_uses_local_regardless_of_fail_mode(monkeypatch): + """When REDIS_URL was never configured (dev/tests) the limiter runs + locally and never 
takes the fail-closed path.""" + limiter = RateLimiter(redis_url=None, fail_mode="closed") + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) - assert limiter.is_allowed("user-d", limit=1, window=60) is True - assert limiter.is_allowed("user-d", limit=1, window=60) is False + assert limiter.check(_specs(limit=1)).allowed is True - limiter.clear() - assert limiter.is_allowed("user-d", limit=1, window=60) is True +def test_clear_resets_in_memory_state(monkeypatch): + limiter = RateLimiter(redis_url=None) + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) + assert limiter.check(_specs(limit=1)).allowed is True + assert limiter.check(_specs(limit=1)).allowed is False + limiter.clear() + assert limiter.check(_specs(limit=1)).allowed is True -def test_retry_after_uses_sliding_window(monkeypatch): - limiter = RateLimiter(redis_url=None) - now = [1000.0] - monkeypatch.setattr("app.rate_limit.time.time", lambda: now[0]) +# --------------------------------------------------------- rejected configs - assert limiter.is_allowed("user-e", limit=2, window=10) is True - assert limiter.is_allowed("user-e", limit=2, window=10) is True - now[0] = 1004.0 - assert limiter.is_allowed("user-e", limit=2, window=10) is False - assert limiter.retry_after("user-e", limit=2, window=10) == 6 +def test_invalid_fail_mode_rejected(): + with pytest.raises(ValueError, match="rate_limit_fail_mode"): + RateLimiter(redis_url=None, fail_mode="nonsense") diff --git a/tests/test_rate_limit_http.py b/tests/test_rate_limit_http.py index 3f7a75c..a5c3db5 100644 --- a/tests/test_rate_limit_http.py +++ b/tests/test_rate_limit_http.py @@ -1,10 +1,18 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2024 GovernsAI. All rights reserved. -"""T-3 HTTP-level 429 integration tests for rate limiting.""" +"""HTTP-level tests for the rate-limit middleware (§1.5c). + +Verifies: + * 429 fires when the minute-bucket limit is exceeded + * Retry-After + X-RateLimit-* headers populate correctly + * Counter resets after the minute window rolls + * Token-count limit denies requests independently of the req/min counter +""" import pytest from app.rate_limit import rate_limiter +from app.settings import settings PRECHECK_URL = "/api/v1/precheck" POSTCHECK_URL = "/api/v1/postcheck" @@ -19,12 +27,8 @@ @pytest.fixture(autouse=True) def _reset_rate_limiter(): - """Clear the in-memory rate limiter state before each test. - - The rate limiter is a module-level singleton. Without this, request - counts from other tests in the same process accumulate and trip the - limit before the 100-request mark. 
- """ + """The rate limiter is a module-level singleton; clear bucket state + around each test so counts from other tests don't leak across.""" rate_limiter.clear() yield rate_limiter.clear() @@ -32,9 +36,10 @@ def _reset_rate_limiter(): @pytest.mark.parametrize("endpoint", RATE_LIMITED_ENDPOINTS) def test_rate_limit_returns_429_after_100_requests( - endpoint, test_client, active_api_key + endpoint, test_client, active_api_key, monkeypatch ): - """First 100 requests must succeed; the 101st must return 429.""" + """First 100 requests in a minute bucket succeed; the 101st returns 429.""" + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) headers = {"X-Governs-Key": active_api_key.key} for i in range(1, 101): @@ -43,18 +48,15 @@ def test_rate_limit_returns_429_after_100_requests( resp.status_code == 200 ), f"Expected 200 on request {i}, got {resp.status_code}: {resp.text}" - # 101st request must be rate-limited resp = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) - assert ( - resp.status_code == 429 - ), f"Expected 429 on request 101, got {resp.status_code}: {resp.text}" + assert resp.status_code == 429, f"Expected 429, got {resp.status_code}" @pytest.mark.parametrize("endpoint", RATE_LIMITED_ENDPOINTS) def test_rate_limit_response_has_retry_after_header( - endpoint, test_client, active_api_key + endpoint, test_client, active_api_key, monkeypatch ): - """The 429 response must include a Retry-After header.""" + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) headers = {"X-Governs-Key": active_api_key.key} for _ in range(100): @@ -62,31 +64,86 @@ def test_rate_limit_response_has_retry_after_header( resp = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) assert resp.status_code == 429 - assert "retry-after" in { - k.lower() for k in resp.headers - }, f"Retry-After header missing from 429 response. 
Headers: {dict(resp.headers)}" + assert "retry-after" in {k.lower() for k in resp.headers} + assert int(resp.headers["retry-after"]) >= 1 + + +@pytest.mark.parametrize("endpoint", RATE_LIMITED_ENDPOINTS) +def test_successful_response_carries_x_ratelimit_headers( + endpoint, test_client, active_api_key, monkeypatch +): + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) + headers = {"X-Governs-Key": active_api_key.key} + + resp = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) + assert resp.status_code == 200 + assert "x-ratelimit-limit" in {k.lower() for k in resp.headers} + assert "x-ratelimit-remaining" in {k.lower() for k in resp.headers} + assert "x-ratelimit-reset" in {k.lower() for k in resp.headers} @pytest.mark.parametrize("endpoint", RATE_LIMITED_ENDPOINTS) -def test_rate_limit_retry_after_matches_sliding_window( +def test_x_ratelimit_remaining_decreases_across_requests( endpoint, test_client, active_api_key, monkeypatch ): - """Retry-After should reflect when the oldest in-window request expires.""" + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) headers = {"X-Governs-Key": active_api_key.key} + + r1 = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) + r2 = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) + rem1 = int(r1.headers["x-ratelimit-remaining"]) + rem2 = int(r2.headers["x-ratelimit-remaining"]) + assert rem2 < rem1 + + +@pytest.mark.parametrize("endpoint", RATE_LIMITED_ENDPOINTS) +def test_429_when_minute_bucket_rolls_admits_new_requests( + endpoint, test_client, active_api_key, monkeypatch +): + """After two full minutes the previous bucket's contribution is gone, so + fresh requests are admitted again.""" now = [1000.0] monkeypatch.setattr("app.rate_limit.time.time", lambda: now[0]) + headers = {"X-Governs-Key": active_api_key.key} - for _ in range(50): - resp = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) - assert resp.status_code == 200 - - now[0] = 1030.0 - for _ in range(50): - resp = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) - assert resp.status_code == 200 + # Fill the current bucket to the limit. + for _ in range(100): + test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) + resp = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) + assert resp.status_code == 429 + # Jump past two full windows — previous bucket is gone entirely. + now[0] = 1000.0 + 120.0 resp = test_client.post(endpoint, json=VALID_PAYLOAD, headers=headers) + assert resp.status_code == 200 + + +def test_token_limit_triggers_429(test_client, active_api_key, monkeypatch): + """A single large-content request exceeds the configured token limit.""" + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) + # Shrink the token budget so one request trips it; content_length gives + # the estimate so we don't need a huge body. 
+ monkeypatch.setattr(settings, "rate_limit_tokens_per_minute", 10) + + headers = {"X-Governs-Key": active_api_key.key} + big_payload = {**VALID_PAYLOAD, "raw_text": "x" * 2000} # ~500 tokens + resp = test_client.post(PRECHECK_URL, json=big_payload, headers=headers) assert resp.status_code == 429 + assert "retry-after" in {k.lower() for k in resp.headers} + + +def test_rate_limit_skipped_for_health_endpoint(test_client, active_api_key, monkeypatch): + """Health probes must not interact with the counter.""" + monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) + # Fill the key's per-key counter via a precheck endpoint, then confirm + # /api/v1/health still returns 200. + headers = {"X-Governs-Key": active_api_key.key} + for _ in range(100): + test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) assert ( - resp.headers["retry-after"] == "30" - ), f"Expected Retry-After: 30, got: {resp.headers.get('retry-after')}" + test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers).status_code + == 429 + ) + + health = test_client.get("/api/v1/health") + assert health.status_code == 200 diff --git a/tests/test_settings.py b/tests/test_settings.py index 3397642..b1d45e4 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -30,6 +30,8 @@ def _set_non_debug_safe_env(monkeypatch): monkeypatch.setenv("WEBHOOK_SECRET", "w" * 32) monkeypatch.setenv("PII_TOKEN_SALT", "p" * 32) monkeypatch.setenv("KEY_HMAC_SECRET", "k" * 32) + monkeypatch.delenv("REDIS_URL", raising=False) + monkeypatch.delenv("RATE_LIMIT_FAIL_MODE", raising=False) @pytest.mark.parametrize( @@ -62,3 +64,92 @@ def test_settings_reject_default_non_debug_secret_markers(monkeypatch, env_var, with pytest.raises(ValueError, match=env_var): Settings(_env_file=None) + + +# --------------------------------------------------------- REDIS_URL posture + + +def test_settings_reject_plaintext_redis_url_outside_debug(monkeypatch): + _set_non_debug_safe_env(monkeypatch) + monkeypatch.setenv("REDIS_URL", "redis://:secret@redis.internal:6379/0") + + with pytest.raises(ValueError, match="REDIS_URL.*rediss"): + Settings(_env_file=None) + + +def test_settings_reject_passwordless_redis_url_outside_debug(monkeypatch): + _set_non_debug_safe_env(monkeypatch) + monkeypatch.setenv("REDIS_URL", "rediss://redis.internal:6379/0") + + with pytest.raises(ValueError, match="REDIS_URL.*password"): + Settings(_env_file=None) + + +def test_settings_accept_tls_password_redis_url_outside_debug(monkeypatch): + _set_non_debug_safe_env(monkeypatch) + monkeypatch.setenv("REDIS_URL", "rediss://:secret@redis.internal:6379/0") + + s = Settings(_env_file=None) + + assert s.redis_url == "rediss://:secret@redis.internal:6379/0" + + +def test_settings_accept_plaintext_redis_url_in_debug(monkeypatch): + monkeypatch.setenv("DEBUG", "true") + monkeypatch.setenv("DATABASE_URL", "sqlite:///./debug.db") + monkeypatch.delenv("DB_URL", raising=False) + monkeypatch.setenv("REDIS_URL", "redis://localhost:6379/0") + + s = Settings(_env_file=None) + + assert s.redis_url == "redis://localhost:6379/0" + + +def test_settings_accept_unset_redis_url(monkeypatch): + """REDIS_URL may be omitted entirely; the posture validator only applies + when a URL is configured.""" + _set_non_debug_safe_env(monkeypatch) + + s = Settings(_env_file=None) + + assert s.redis_url is None + + +# --------------------------------------------------- RATE_LIMIT_FAIL_MODE + + +def test_settings_reject_invalid_fail_mode(monkeypatch): + _set_non_debug_safe_env(monkeypatch) + 
monkeypatch.setenv("RATE_LIMIT_FAIL_MODE", "teapot") + + with pytest.raises(ValueError, match="RATE_LIMIT_FAIL_MODE"): + Settings(_env_file=None) + + +def test_settings_reject_local_fail_mode_outside_debug(monkeypatch): + """`local` multiplies the effective quota by N replicas — reject outside + debug mode (Cipher review on precheck#31).""" + _set_non_debug_safe_env(monkeypatch) + monkeypatch.setenv("RATE_LIMIT_FAIL_MODE", "local") + + with pytest.raises(ValueError, match="RATE_LIMIT_FAIL_MODE=local"): + Settings(_env_file=None) + + +def test_settings_accept_local_fail_mode_in_debug(monkeypatch): + monkeypatch.setenv("DEBUG", "true") + monkeypatch.setenv("DATABASE_URL", "sqlite:///./debug.db") + monkeypatch.delenv("DB_URL", raising=False) + monkeypatch.setenv("RATE_LIMIT_FAIL_MODE", "local") + + s = Settings(_env_file=None) + + assert s.rate_limit_fail_mode == "local" + + +def test_settings_default_fail_mode_is_closed(monkeypatch): + _set_non_debug_safe_env(monkeypatch) + + s = Settings(_env_file=None) + + assert s.rate_limit_fail_mode == "closed" From 4e2f91589b66f5293d9ec6343d87aae49d850d5c Mon Sep 17 00:00:00 2001 From: Shaishav Pidadi Date: Sat, 25 Apr 2026 01:02:52 -0400 Subject: [PATCH 2/2] fix(rate-limit): apply black/isort and resolve mypy type error - black/isort formatting on rate_limit middleware + tests - cast SQLAlchemy Column[str] to str for org_id return type --- app/rate_limit.py | 17 ++++++++--------- app/rate_limit_middleware.py | 7 +++---- tests/test_rate_limit.py | 14 ++++++-------- tests/test_rate_limit_http.py | 4 +++- 4 files changed, 20 insertions(+), 22 deletions(-) diff --git a/app/rate_limit.py b/app/rate_limit.py index 633263e..85d14f0 100644 --- a/app/rate_limit.py +++ b/app/rate_limit.py @@ -47,7 +47,7 @@ class LimitSpec: """One rate-limit dimension to evaluate on this request.""" name: str # e.g. "req-key", "req-org", "tokens-key", "tokens-org" - key: str # Redis key prefix, e.g. "req:key:" (bucket appended at runtime) + key: str # Redis key prefix, e.g. 
"req:key:" (bucket appended at runtime) limit: int cost: int # 1 for request counters, token count for token counters @@ -56,7 +56,7 @@ class LimitSpec: class LimitState: limit: int remaining: int - reset_in: int # seconds until the current minute bucket ends + reset_in: int # seconds until the current minute bucket ends retry_after: int # seconds until a request of this cost would be permitted @@ -141,9 +141,7 @@ def check(self, specs: Sequence[LimitSpec]) -> RateLimitResult: try: return self._check_redis(specs, bucket, elapsed_in_current) except Exception as exc: - logger.warning( - "Redis rate limiter request failed: %s", type(exc).__name__ - ) + logger.warning("Redis rate limiter request failed: %s", type(exc).__name__) return self._handle_unavailable( specs, reason=f"redis-error:{type(exc).__name__}" ) @@ -183,7 +181,10 @@ def _handle_unavailable( @staticmethod def _unknown_state(spec: LimitSpec) -> LimitState: return LimitState( - limit=spec.limit, remaining=spec.limit, reset_in=WINDOW_SECONDS, retry_after=0 + limit=spec.limit, + remaining=spec.limit, + reset_in=WINDOW_SECONDS, + retry_after=0, ) # ---------------------------------------------------------------- Redis @@ -285,9 +286,7 @@ def _check_local( def _gc_local(self, bucket: int) -> None: """Drop buckets older than the previous one.""" stale = [ - k - for k in self._local_buckets - if int(k.rsplit(":", 1)[1]) < bucket - 1 + k for k in self._local_buckets if int(k.rsplit(":", 1)[1]) < bucket - 1 ] for k in stale: self._local_buckets.pop(k, None) diff --git a/app/rate_limit_middleware.py b/app/rate_limit_middleware.py index d4d30ca..96a61c8 100644 --- a/app/rate_limit_middleware.py +++ b/app/rate_limit_middleware.py @@ -74,7 +74,8 @@ def _lookup_org_id(raw_key: str) -> Optional[str]: record = session.query(APIKey).filter(APIKey.key_hash == key_hash).first() if record is None: return None - return record.org_id + org_id = record.org_id + return str(org_id) if org_id is not None else None except SQLAlchemyError as exc: logger.warning("Rate-limit org lookup failed: %s", type(exc).__name__) return None @@ -118,9 +119,7 @@ async def rate_limit_middleware( key_hash = hash_api_key(raw_key) org_id = _lookup_org_id(raw_key) token_cost = _tokens_estimate(request) - specs = specs_for_request( - key_id=key_hash, org_id=org_id, token_cost=token_cost - ) + specs = specs_for_request(key_id=key_hash, org_id=org_id, token_cost=token_cost) result = rate_limiter.check(specs) diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py index bb69589..be93c14 100644 --- a/tests/test_rate_limit.py +++ b/tests/test_rate_limit.py @@ -9,14 +9,16 @@ import pytest from app.rate_limit import ( + WINDOW_SECONDS, LimitSpec, RateLimiter, - WINDOW_SECONDS, specs_for_request, ) -def _specs(*, limit: int, cost: int = 1, key: str = "req:key:k1", name: str = "req-key"): +def _specs( + *, limit: int, cost: int = 1, key: str = "req:key:k1", name: str = "req-key" +): return [LimitSpec(name=name, key=key, limit=limit, cost=cost)] @@ -88,12 +90,8 @@ def test_per_key_and_per_org_counters_are_independent(monkeypatch): def specs(key_id: str, org_id: str, req_limit: int, org_limit: int): return [ - LimitSpec( - name="req-key", key=f"req:key:{key_id}", limit=req_limit, cost=1 - ), - LimitSpec( - name="req-org", key=f"req:org:{org_id}", limit=org_limit, cost=1 - ), + LimitSpec(name="req-key", key=f"req:key:{key_id}", limit=req_limit, cost=1), + LimitSpec(name="req-org", key=f"req:org:{org_id}", limit=org_limit, cost=1), ] # Key A exhausts its per-key limit (2) but stays 
under the per-org limit (10). diff --git a/tests/test_rate_limit_http.py b/tests/test_rate_limit_http.py index a5c3db5..a8423cc 100644 --- a/tests/test_rate_limit_http.py +++ b/tests/test_rate_limit_http.py @@ -132,7 +132,9 @@ def test_token_limit_triggers_429(test_client, active_api_key, monkeypatch): assert "retry-after" in {k.lower() for k in resp.headers} -def test_rate_limit_skipped_for_health_endpoint(test_client, active_api_key, monkeypatch): +def test_rate_limit_skipped_for_health_endpoint( + test_client, active_api_key, monkeypatch +): """Health probes must not interact with the counter.""" monkeypatch.setattr("app.rate_limit.time.time", lambda: 1000.0) # Fill the key's per-key counter via a precheck endpoint, then confirm