From 5a112717d4edc10f99fafd1cdf07d52dcde11c40 Mon Sep 17 00:00:00 2001 From: Shaishav Pidadi Date: Thu, 23 Apr 2026 16:29:25 -0400 Subject: [PATCH 1/2] feat(precheck): cache identical allow decisions Refs: 77df8824-f4b8-4160-8570-9a9ac16d1fa1 --- PROJECT_SPECS.md | 2 + app/api.py | 112 ++++++++++++++++-------- app/decision_cache.py | 132 +++++++++++++++++++++++++++++ app/settings.py | 1 + env.example | 1 + tests/test_precheck_allow_cache.py | 127 +++++++++++++++++++++++++++ 6 files changed, 339 insertions(+), 36 deletions(-) create mode 100644 app/decision_cache.py create mode 100644 tests/test_precheck_allow_cache.py diff --git a/PROJECT_SPECS.md b/PROJECT_SPECS.md index 8a7de4f..6a6e0ec 100644 --- a/PROJECT_SPECS.md +++ b/PROJECT_SPECS.md @@ -32,6 +32,7 @@ GovernsAI Precheck is a policy evaluation and PII redaction service that provide - **Dead Letter Queue (DLQ)**: Failed webhook deliveries stored in JSONL format - **Retry logic**: Exponential backoff with configurable retry attempts - **Event schema**: Versioned event format for backward compatibility +- **Allow-decision cache**: Redis-first cache for identical `allow` decisions with a short TTL ### 5. Failure Contract & Error Handling - **Configurable error behavior**: `block`, `pass`, or `best_effort` modes @@ -194,6 +195,7 @@ GET /api/metrics ### Standard Response Headers - **`X-Request-ID`**: Unique UUID generated for each request for trace correlation - **`X-Response-Time-Ms`**: Integer request duration in milliseconds added to every response +- **`X-Cache`**: Cache outcome for `/api/v1/precheck` responses (`HIT` or `MISS`) ### Precheck Endpoint ``` diff --git a/app/api.py b/app/api.py index 9518efe..4a8a551 100644 --- a/app/api.py +++ b/app/api.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from .auth import AuthContext, require_api_key +from .decision_cache import allow_decision_cache from .events import emit_event from .log import audit_log from .metrics import ( @@ -42,6 +43,28 @@ def _ensure_correlation_id(corr_id: Optional[str]) -> str: return corr_id or f"corr-{secrets.token_hex(12)}" +def _build_allow_cache_key(req: PrePostCheckRequest, org_id: Optional[str]) -> str: + policy_version = req.policy_config.version if req.policy_config else "legacy" + payload = { + "content": req.raw_text, + "tool": req.tool, + "policy_version": policy_version, + } + digest = hashlib.sha256( + json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + ).hexdigest() + return f"precheck:allow:{org_id or 'no-org'}:{digest}" + + +def _is_cacheable_allow_result(req: PrePostCheckRequest, result: dict) -> bool: + return ( + req.budget_context is None + and result.get("decision") == "allow" + and result.get("budget_status") is None + and result.get("budget_info") is None + ) + + def extract_pii_info_from_reasons( reasons: Optional[List[str]], ) -> Tuple[List[str], float]: @@ -272,7 +295,9 @@ async def metrics(): @router.post("/v1/precheck", response_model=DecisionResponse) async def precheck( - req: PrePostCheckRequest, auth: AuthContext = Depends(require_api_key) + req: PrePostCheckRequest, + response: Response, + auth: AuthContext = Depends(require_api_key), ): """Precheck endpoint for policy evaluation and PII redaction""" api_key = auth.raw_key @@ -303,50 +328,65 @@ async def precheck( start_time = time.time() start_ts = int(start_time) + cache_key = _build_allow_cache_key(req, org_id) try: logger.debug( "precheck request", extra={"tool": req.tool, "corr_id": correlation_id} ) - # Use new policy evaluation with payload policies - policy_config = req.policy_config.model_dump() if req.policy_config else None - tool_config = req.tool_config.model_dump() if req.tool_config else None - budget_context = req.budget_context.model_dump() if req.budget_context else None - result = evaluate_with_payload_policy( - tool=req.tool, - scope=req.scope, - raw_text=req.raw_text, - now=start_ts, - direction="ingress", - policy_config=policy_config, - tool_config=tool_config, - user_id=user_id, - budget_context=budget_context, - ) - - # Add budget info to result if not already present - if user_id and tool_config and policy_config and budget_context: - from .policies import _add_budget_info_to_result + cached_result = allow_decision_cache.get(cache_key) + if cached_result is not None: + response.headers["X-Cache"] = "HIT" + result = cached_result + else: + response.headers["X-Cache"] = "MISS" - result = _add_budget_info_to_result( - result, - user_id, - req.tool, - req.raw_text, - tool_config, - policy_config, - budget_context, + # Use new policy evaluation with payload policies + policy_config = ( + req.policy_config.model_dump() if req.policy_config else None + ) + tool_config = req.tool_config.model_dump() if req.tool_config else None + budget_context = ( + req.budget_context.model_dump() if req.budget_context else None + ) + result = evaluate_with_payload_policy( + tool=req.tool, + scope=req.scope, + raw_text=req.raw_text, + now=start_ts, + direction="ingress", + policy_config=policy_config, + tool_config=tool_config, + user_id=user_id, + budget_context=budget_context, ) - # Metrics: Record policy evaluation - policy_eval_duration = time.time() - start_time - record_policy_evaluation( - tool=req.tool, - direction="ingress", - policy_id=result.get("policy_id", "unknown"), - duration=policy_eval_duration, - ) + # Add budget info to result if not already present + if user_id and tool_config and policy_config and budget_context: + from .policies import _add_budget_info_to_result + + result = _add_budget_info_to_result( + result, + user_id, + req.tool, + req.raw_text, + tool_config, + policy_config, + budget_context, + ) + + if _is_cacheable_allow_result(req, result): + allow_decision_cache.set(cache_key, result) + + # Metrics: Record policy evaluation + policy_eval_duration = time.time() - start_time + record_policy_evaluation( + tool=req.tool, + direction="ingress", + policy_id=result.get("policy_id", "unknown"), + duration=policy_eval_duration, + ) # Extract PII information from reasons pii_types, confidence = extract_pii_info_from_reasons(result.get("reasons", [])) diff --git a/app/decision_cache.py b/app/decision_cache.py new file mode 100644 index 0000000..93d2d1f --- /dev/null +++ b/app/decision_cache.py @@ -0,0 +1,132 @@ +import json +import logging +import threading +import time +from typing import Dict, Optional, Tuple + +from .settings import settings + +logger = logging.getLogger(__name__) + +try: + import redis +except Exception: # pragma: no cover - exercised in environments without redis package + redis = None + + +class AllowDecisionCache: + """Redis-first cache for cacheable allow decisions.""" + + def __init__(self, redis_url: Optional[str] = None, ttl_seconds: int = 60): + self.redis_client = None + self.ttl_seconds = max(0, int(ttl_seconds)) + self._local_lock = threading.Lock() + self._local_store: Dict[str, Tuple[float, str]] = {} + self._cleanup_interval = 60.0 + self._last_cleanup = 0.0 + + if redis_url and redis is not None: + try: + self.redis_client = redis.from_url(redis_url) + self.redis_client.ping() + except Exception as exc: + logger.warning( + "Failed to connect to Redis for allow-decision cache: %s", + type(exc).__name__, + ) + self.redis_client = None + elif redis_url and redis is None: + logger.warning( + "redis package not installed; using in-memory allow-decision cache" + ) + + def get(self, key: str) -> Optional[Dict]: + if self.ttl_seconds <= 0: + return None + + if self.redis_client: + try: + return self._get_redis(key) + except Exception as exc: + logger.warning( + "Redis allow-decision cache unavailable; falling back to in-memory cache: %s", + type(exc).__name__, + ) + + return self._get_local(key) + + def set(self, key: str, value: Dict) -> None: + if self.ttl_seconds <= 0: + return + + payload = json.dumps(value) + + if self.redis_client: + try: + self.redis_client.setex(key, self.ttl_seconds, payload) + return + except Exception as exc: + logger.warning( + "Redis allow-decision cache unavailable; falling back to in-memory cache: %s", + type(exc).__name__, + ) + + self._set_local(key, payload) + + def clear(self) -> None: + with self._local_lock: + self._local_store.clear() + self._last_cleanup = 0.0 + + def _get_redis(self, key: str) -> Optional[Dict]: + payload = self.redis_client.get(key) + if payload is None: + return None + if isinstance(payload, bytes): + payload = payload.decode("utf-8") + return json.loads(payload) + + def _get_local(self, key: str) -> Optional[Dict]: + current_time = time.time() + with self._local_lock: + self._cleanup_local_state(current_time) + item = self._local_store.get(key) + if item is None: + return None + + expires_at, payload = item + if expires_at <= current_time: + self._local_store.pop(key, None) + return None + + try: + return json.loads(payload) + except json.JSONDecodeError: + self._local_store.pop(key, None) + return None + + def _set_local(self, key: str, payload: str) -> None: + current_time = time.time() + expires_at = current_time + self.ttl_seconds + with self._local_lock: + self._cleanup_local_state(current_time) + self._local_store[key] = (expires_at, payload) + + def _cleanup_local_state(self, current_time: float) -> None: + if current_time - self._last_cleanup < self._cleanup_interval: + return + + expired_keys = [ + key + for key, (expires_at, _payload) in self._local_store.items() + if expires_at <= current_time + ] + for key in expired_keys: + self._local_store.pop(key, None) + + self._last_cleanup = current_time + + +allow_decision_cache = AllowDecisionCache( + settings.redis_url, settings.precheck_allow_cache_ttl_seconds +) diff --git a/app/settings.py b/app/settings.py index cc84bed..181897d 100644 --- a/app/settings.py +++ b/app/settings.py @@ -19,6 +19,7 @@ class Settings(BaseSettings): # Redis configuration (optional) redis_url: Optional[str] = None + precheck_allow_cache_ttl_seconds: int = 60 # Public base URL for cloud mode public_base: Optional[str] = None diff --git a/env.example b/env.example index a6af17f..d18f1f3 100644 --- a/env.example +++ b/env.example @@ -10,6 +10,7 @@ DB_URL=sqlite:///./local.db # Redis Configuration (optional) # REDIS_URL=redis://localhost:6379 +PRECHECK_ALLOW_CACHE_TTL_SECONDS=60 # Public Base URL (for cloud mode) # PUBLIC_BASE=https://your-domain.com diff --git a/tests/test_precheck_allow_cache.py b/tests/test_precheck_allow_cache.py new file mode 100644 index 0000000..37dea1c --- /dev/null +++ b/tests/test_precheck_allow_cache.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024 GovernsAI. All rights reserved. +"""HTTP-level tests for cached allow decisions on /api/v1/precheck.""" + +import pytest + +from app.decision_cache import allow_decision_cache +from app.rate_limit import rate_limiter + +PRECHECK_URL = "/api/v1/precheck" +VALID_PAYLOAD = { + "tool": "model.chat", + "scope": "net.external", + "raw_text": "Cache this exact message.", + "policy_config": {"version": "policy-v1"}, +} + + +@pytest.fixture(autouse=True) +def _reset_runtime_state(): + allow_decision_cache.clear() + rate_limiter.clear() + yield + allow_decision_cache.clear() + rate_limiter.clear() + + +@pytest.fixture(autouse=True) +def _stub_side_effects(monkeypatch): + async def _noop_emit_event(*_args, **_kwargs): + return None + + monkeypatch.setattr("app.api.emit_event", _noop_emit_event) + monkeypatch.setattr("app.api.audit_log", lambda *_args, **_kwargs: None) + + +def test_identical_allow_request_hits_cache_within_ttl( + test_client, active_api_key, monkeypatch +): + calls = {"count": 0} + + def fake_evaluate_with_payload_policy(**_kwargs): + calls["count"] += 1 + return { + "decision": "allow", + "raw_text_out": VALID_PAYLOAD["raw_text"], + "reasons": ["policy.allow"], + "policy_id": "tool-access", + "ts": 1000, + } + + monkeypatch.setattr( + "app.api.evaluate_with_payload_policy", fake_evaluate_with_payload_policy + ) + + headers = {"X-Governs-Key": active_api_key.key} + first = test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) + second = test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) + + assert first.status_code == 200 + assert second.status_code == 200 + assert first.headers["x-cache"] == "MISS" + assert second.headers["x-cache"] == "HIT" + assert first.json() == second.json() + assert calls["count"] == 1 + + +def test_transform_decisions_are_never_cached(test_client, active_api_key, monkeypatch): + calls = {"count": 0} + + def fake_evaluate_with_payload_policy(**_kwargs): + calls["count"] += 1 + return { + "decision": "transform", + "raw_text_out": "[REDACTED]", + "reasons": ["pii.redacted:PII:email_address"], + "policy_id": "tool-access", + "ts": 1000 + calls["count"], + } + + monkeypatch.setattr( + "app.api.evaluate_with_payload_policy", fake_evaluate_with_payload_policy + ) + + headers = {"X-Governs-Key": active_api_key.key} + first = test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) + second = test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) + + assert first.status_code == 200 + assert second.status_code == 200 + assert first.headers["x-cache"] == "MISS" + assert second.headers["x-cache"] == "MISS" + assert calls["count"] == 2 + + +def test_allow_cache_expires_after_ttl(test_client, active_api_key, monkeypatch): + calls = {"count": 0} + now = [1000.0] + + def fake_evaluate_with_payload_policy(**_kwargs): + calls["count"] += 1 + return { + "decision": "allow", + "raw_text_out": VALID_PAYLOAD["raw_text"], + "reasons": ["policy.allow"], + "policy_id": "tool-access", + "ts": 1000 + calls["count"], + } + + monkeypatch.setattr( + "app.api.evaluate_with_payload_policy", fake_evaluate_with_payload_policy + ) + monkeypatch.setattr("app.decision_cache.time.time", lambda: now[0]) + + headers = {"X-Governs-Key": active_api_key.key} + first = test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) + second = test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) + now[0] = 1061.0 + third = test_client.post(PRECHECK_URL, json=VALID_PAYLOAD, headers=headers) + + assert first.status_code == 200 + assert second.status_code == 200 + assert third.status_code == 200 + assert first.headers["x-cache"] == "MISS" + assert second.headers["x-cache"] == "HIT" + assert third.headers["x-cache"] == "MISS" + assert calls["count"] == 2 From c2fd436a15740db3f52bbefdbcfb9e51741198bb Mon Sep 17 00:00:00 2001 From: Shaishav Pidadi Date: Fri, 24 Apr 2026 10:03:43 -0400 Subject: [PATCH 2/2] fix(precheck): harden decision cache redis typing Refs: 77df8824-f4b8-4160-8570-9a9ac16d1fa1 --- app/decision_cache.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/app/decision_cache.py b/app/decision_cache.py index 93d2d1f..7f0eb16 100644 --- a/app/decision_cache.py +++ b/app/decision_cache.py @@ -2,23 +2,35 @@ import logging import threading import time -from typing import Dict, Optional, Tuple +from types import ModuleType +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from .settings import settings logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from redis import Redis +else: # pragma: no cover - runtime-only fallback for optional typing + Redis = Any + +_redis_module: Optional[ModuleType] = None +imported_redis_module: Optional[ModuleType] + try: - import redis + import redis as imported_redis_module except Exception: # pragma: no cover - exercised in environments without redis package - redis = None + imported_redis_module = None + +_redis_module = imported_redis_module +redis: Optional[ModuleType] = _redis_module class AllowDecisionCache: """Redis-first cache for cacheable allow decisions.""" def __init__(self, redis_url: Optional[str] = None, ttl_seconds: int = 60): - self.redis_client = None + self.redis_client: Optional[Redis] = None self.ttl_seconds = max(0, int(ttl_seconds)) self._local_lock = threading.Lock() self._local_store: Dict[str, Tuple[float, str]] = {} @@ -79,12 +91,28 @@ def clear(self) -> None: self._last_cleanup = 0.0 def _get_redis(self, key: str) -> Optional[Dict]: - payload = self.redis_client.get(key) + client = self.redis_client + if client is None: + return None + + payload = client.get(key) if payload is None: return None + if isinstance(payload, bytes): - payload = payload.decode("utf-8") - return json.loads(payload) + payload_text = payload.decode("utf-8") + elif isinstance(payload, bytearray): + payload_text = bytes(payload).decode("utf-8") + elif isinstance(payload, str): + payload_text = payload + else: + logger.warning( + "Unexpected allow-decision cache payload type from Redis: %s", + type(payload).__name__, + ) + return None + + return json.loads(payload_text) def _get_local(self, key: str) -> Optional[Dict]: current_time = time.time()