diff --git a/src/governs_ai/client.py b/src/governs_ai/client.py index a300b5b..529591a 100644 --- a/src/governs_ai/client.py +++ b/src/governs_ai/client.py @@ -1,12 +1,52 @@ import asyncio +import random import time -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import httpx from .memory import MemoryClient from .types import BudgetResult, PrecheckResult +_DEFAULT_RETRY_INITIAL_DELAY = 1.0 +_DEFAULT_RETRY_BACKOFF_FACTOR = 2.0 +_DEFAULT_RETRY_MAX_DELAY = 30.0 +# Per-call precheck overrides consumed via **kwargs; anything else is forwarded +# into the request body for forward compatibility. +_PRECHECK_CONFIG_KEYS = frozenset( + { + "timeout", + "max_retries", + "retry_initial_delay", + "retry_backoff_factor", + "retry_max_delay", + "jitter", + "scope", + "user_id", + "corr_id", + "tags", + "payload", + } +) + + +def _is_retryable_status(status_code: int) -> bool: + return status_code >= 500 or status_code == 429 + + +def _compute_retry_delay( + attempt: int, + initial: float, + factor: float, + max_delay: float, + jitter: bool, +) -> float: + """Exponential backoff with optional jitter, capped at ``max_delay``.""" + delay = min(initial * (factor**attempt), max_delay) + if jitter: + delay *= random.uniform(0.5, 1.5) + return min(delay, max_delay) + class GovernsAIError(Exception): """Base error for GovernsAI SDK""" @@ -65,13 +105,50 @@ def __repr__(self): return f"" def _get_payload( - self, content: str, tool: str, org_id: Optional[str] + self, + content: str, + tool: str, + org_id: Optional[str], + *, + scope: str = "net.external", + user_id: Optional[str] = None, + corr_id: Optional[str] = None, + tags: Optional[List[str]] = None, + extra_payload: Optional[Dict[str, Any]] = None, + extras: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - return { + payload: Dict[str, Any] = { "tool": tool, "raw_text": content, "org_id": org_id or self.org_id, - "scope": "net.external", + "scope": scope, + } + if user_id is not None: + payload["user_id"] = user_id + if corr_id is not None: + payload["corr_id"] = corr_id + if tags is not None: + payload["tags"] = tags + if extra_payload is not None: + payload["payload"] = extra_payload + if extras: + # Unknown kwargs pass through so the SDK tolerates server schema growth. + payload.update(extras) + return payload + + def _resolve_retry_config(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Pop retry/config kwargs, falling back to client defaults.""" + return { + "timeout": kwargs.pop("timeout", self.timeout), + "max_retries": kwargs.pop("max_retries", self.max_retries), + "retry_initial_delay": kwargs.pop( + "retry_initial_delay", _DEFAULT_RETRY_INITIAL_DELAY + ), + "retry_backoff_factor": kwargs.pop( + "retry_backoff_factor", _DEFAULT_RETRY_BACKOFF_FACTOR + ), + "retry_max_delay": kwargs.pop("retry_max_delay", _DEFAULT_RETRY_MAX_DELAY), + "jitter": kwargs.pop("jitter", False), } def _parse_response( @@ -109,29 +186,73 @@ def precheck( content: str, tool: str, org_id: Optional[str] = None, + **kwargs: Any, ) -> PrecheckResult: + """Check a content/tool request for governance compliance. + + Args: + content: Raw user-facing text to evaluate. + tool: Tool identifier (e.g. ``"model.chat"``). + org_id: Organization owning the request. Falls back to the + client-level ``org_id``. + + Keyword Args: + timeout: Per-call timeout override (seconds). + max_retries: Override the client-level retry ceiling. + retry_initial_delay: First-attempt backoff delay in seconds + (default 1.0). + retry_backoff_factor: Multiplier applied per retry (default 2.0). + retry_max_delay: Upper bound on a single backoff sleep (default 30.0). + jitter: Multiply backoff by ``uniform(0.5, 1.5)`` when True. + scope: Override the ``scope`` field (default ``"net.external"``). + user_id, corr_id, tags, payload: Optional request body fields. + Any other kwargs are forwarded into the request body verbatim. + + Returns: + :class:`PrecheckResult` with decision, redacted_content, reasons, + and client-measured ``latency_ms``. + + Raises: + PrecheckError: On non-retryable 4xx or exhausted retries. """ - Check a request for governance compliance. - """ - payload = self._get_payload(content, tool, org_id) + retry = self._resolve_retry_config(kwargs) + payload = self._get_payload( + content, + tool, + org_id, + scope=kwargs.pop("scope", "net.external"), + user_id=kwargs.pop("user_id", None), + corr_id=kwargs.pop("corr_id", None), + tags=kwargs.pop("tags", None), + extra_payload=kwargs.pop("payload", None), + extras={k: v for k, v in kwargs.items() if k not in _PRECHECK_CONFIG_KEYS}, + ) start_time = time.time() last_error_msg = "Unknown error" - for attempt in range(self.max_retries + 1): + for attempt in range(retry["max_retries"] + 1): try: - with httpx.Client(timeout=self.timeout) as client: + with httpx.Client(timeout=retry["timeout"]) as client: response = client.post( f"{self.base_url}/api/v1/precheck", json=payload, headers=self.headers, ) - if response.status_code >= 500 or response.status_code == 429: + if _is_retryable_status(response.status_code): last_error_msg = ( f"HTTP {response.status_code} {response.reason_phrase}" ) - if attempt < self.max_retries: - time.sleep(2**attempt) + if attempt < retry["max_retries"]: + time.sleep( + _compute_retry_delay( + attempt, + retry["retry_initial_delay"], + retry["retry_backoff_factor"], + retry["retry_max_delay"], + retry["jitter"], + ) + ) continue else: break @@ -140,8 +261,16 @@ def precheck( return self._parse_response(response, latency_ms) except (httpx.RequestError, httpx.TimeoutException) as e: last_error_msg = str(e) - if attempt < self.max_retries: - time.sleep(2**attempt) + if attempt < retry["max_retries"]: + time.sleep( + _compute_retry_delay( + attempt, + retry["retry_initial_delay"], + retry["retry_backoff_factor"], + retry["retry_max_delay"], + retry["jitter"], + ) + ) continue raise PrecheckError(f"Max retries exceeded: {last_error_msg}") @@ -151,29 +280,47 @@ async def async_precheck( content: str, tool: str, org_id: Optional[str] = None, + **kwargs: Any, ) -> PrecheckResult: - """ - Async version of precheck. - """ - payload = self._get_payload(content, tool, org_id) + """Async counterpart of :meth:`precheck` accepting the same kwargs.""" + retry = self._resolve_retry_config(kwargs) + payload = self._get_payload( + content, + tool, + org_id, + scope=kwargs.pop("scope", "net.external"), + user_id=kwargs.pop("user_id", None), + corr_id=kwargs.pop("corr_id", None), + tags=kwargs.pop("tags", None), + extra_payload=kwargs.pop("payload", None), + extras={k: v for k, v in kwargs.items() if k not in _PRECHECK_CONFIG_KEYS}, + ) start_time = time.time() last_error_msg = "Unknown error" - for attempt in range(self.max_retries + 1): + for attempt in range(retry["max_retries"] + 1): try: - async with httpx.AsyncClient(timeout=self.timeout) as client: + async with httpx.AsyncClient(timeout=retry["timeout"]) as client: response = await client.post( f"{self.base_url}/api/v1/precheck", json=payload, headers=self.headers, ) - if response.status_code >= 500 or response.status_code == 429: + if _is_retryable_status(response.status_code): last_error_msg = ( f"HTTP {response.status_code} {response.reason_phrase}" ) - if attempt < self.max_retries: - await asyncio.sleep(2**attempt) + if attempt < retry["max_retries"]: + await asyncio.sleep( + _compute_retry_delay( + attempt, + retry["retry_initial_delay"], + retry["retry_backoff_factor"], + retry["retry_max_delay"], + retry["jitter"], + ) + ) continue else: break @@ -182,8 +329,16 @@ async def async_precheck( return self._parse_response(response, latency_ms) except (httpx.RequestError, httpx.TimeoutException) as e: last_error_msg = str(e) - if attempt < self.max_retries: - await asyncio.sleep(2**attempt) + if attempt < retry["max_retries"]: + await asyncio.sleep( + _compute_retry_delay( + attempt, + retry["retry_initial_delay"], + retry["retry_backoff_factor"], + retry["retry_max_delay"], + retry["jitter"], + ) + ) continue raise PrecheckError(f"Max retries exceeded: {last_error_msg}") diff --git a/tests/test_client.py b/tests/test_client.py index 7b729bc..adbc8d9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -95,3 +95,166 @@ def test_precheck_max_retries_exceeded(client, httpx_mock): assert "Max retries exceeded" in str(excinfo.value) assert len(httpx_mock.get_requests()) == 4 + + +def test_precheck_result_fields_populated_from_response(client, httpx_mock): + """PrecheckResult fields should all be populated from the response JSON.""" + httpx_mock.add_response( + method="POST", + url="https://api.governs.ai/api/v1/precheck", + json={ + "decision": "redact", + "redacted_content": "Hello [REDACTED]", + "reasons": ["policy.pii", "policy.secret"], + }, + status_code=200, + ) + + result = client.precheck(content="Hello 123-45-6789", tool="model.chat") + + assert result.decision == "redact" + assert result.redacted_content == "Hello [REDACTED]" + assert result.reasons == ["policy.pii", "policy.secret"] + assert result.latency_ms > 0 + + +def test_precheck_forwards_kwargs_into_request_body(client, httpx_mock): + """Extra kwargs should be forwarded into the precheck request body.""" + httpx_mock.add_response( + method="POST", + url="https://api.governs.ai/api/v1/precheck", + json={"decision": "allow", "reasons": []}, + status_code=200, + ) + + client.precheck( + content="Hi", + tool="model.chat", + org_id="org-override", + user_id="user-7", + corr_id="corr-42", + tags=["unit-test"], + payload={"messages": [{"role": "user", "content": "Hi"}]}, + scope="net.internal", + custom_field="pass-through", + ) + + request = httpx_mock.get_request() + body = json.loads(request.read().decode()) + assert body["tool"] == "model.chat" + assert body["raw_text"] == "Hi" + assert body["org_id"] == "org-override" + assert body["user_id"] == "user-7" + assert body["corr_id"] == "corr-42" + assert body["tags"] == ["unit-test"] + assert body["scope"] == "net.internal" + assert body["payload"] == {"messages": [{"role": "user", "content": "Hi"}]} + # Forward-compat: unknown fields pass through unchanged. + assert body["custom_field"] == "pass-through" + + +def test_precheck_uses_configurable_backoff(client, httpx_mock): + """``retry_initial_delay`` and ``retry_backoff_factor`` should drive sleeps.""" + httpx_mock.add_response(status_code=500) + httpx_mock.add_response(status_code=500) + httpx_mock.add_response(status_code=200, json={"decision": "allow"}) + + with patch("time.sleep") as mock_sleep: + client.precheck( + content="Hello", + tool="test-tool", + retry_initial_delay=0.25, + retry_backoff_factor=3.0, + jitter=False, + ) + + assert mock_sleep.call_count == 2 + delays = [call.args[0] for call in mock_sleep.call_args_list] + # attempt 0 -> 0.25, attempt 1 -> 0.25 * 3 = 0.75 + assert delays[0] == pytest.approx(0.25) + assert delays[1] == pytest.approx(0.75) + + +def test_precheck_max_retries_override_via_kwargs(client, httpx_mock): + """Per-call ``max_retries`` override should limit attempts.""" + httpx_mock.add_response(status_code=500) + httpx_mock.add_response(status_code=500) + + with patch("time.sleep"): + with pytest.raises(PrecheckError): + client.precheck( + content="Hello", + tool="test-tool", + max_retries=1, + ) + + # max_retries=1 → 2 total attempts. + assert len(httpx_mock.get_requests()) == 2 + + +def test_precheck_does_not_retry_on_4xx(client, httpx_mock): + """4xx responses should surface immediately without retry.""" + httpx_mock.add_response( + status_code=400, + json={"error": "bad request: missing tool"}, + ) + + with patch("time.sleep") as mock_sleep: + with pytest.raises(PrecheckError) as excinfo: + client.precheck(content="Hello", tool="test-tool") + + assert len(httpx_mock.get_requests()) == 1 + assert mock_sleep.call_count == 0 + assert excinfo.value.status_code == 400 + assert "bad request" in str(excinfo.value) + + +def test_precheck_timeout_override_via_kwargs(client, httpx_mock): + """Per-call ``timeout`` kwarg should not raise on happy path.""" + httpx_mock.add_response( + method="POST", + url="https://api.governs.ai/api/v1/precheck", + json={"decision": "allow", "reasons": []}, + status_code=200, + ) + + result = client.precheck(content="Hi", tool="test-tool", timeout=5.0) + assert result.decision == "allow" + + +@pytest.mark.asyncio +async def test_async_precheck_result_fields_populated(client, httpx_mock): + """Async variant should populate all PrecheckResult fields.""" + httpx_mock.add_response( + method="POST", + url="https://api.governs.ai/api/v1/precheck", + json={ + "decision": "deny", + "redacted_content": None, + "reasons": ["policy.denied"], + }, + status_code=200, + ) + + result = await client.async_precheck(content="Hello", tool="test-tool") + + assert result.decision == "deny" + assert result.reasons == ["policy.denied"] + assert result.latency_ms > 0 + + +@pytest.mark.asyncio +async def test_async_precheck_max_retries_exceeded(client, httpx_mock): + for _ in range(3): + httpx_mock.add_response(status_code=502) + + with patch("asyncio.sleep"): + with pytest.raises(PrecheckError) as excinfo: + await client.async_precheck( + content="Hello", + tool="test-tool", + max_retries=2, + ) + + assert len(httpx_mock.get_requests()) == 3 + assert "Max retries exceeded" in str(excinfo.value)