Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 180 additions & 25 deletions src/governs_ai/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,52 @@
import asyncio
import random
import time
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import httpx

from .memory import MemoryClient
from .types import BudgetResult, PrecheckResult

# Backoff defaults used when a call does not override them via kwargs.
_DEFAULT_RETRY_INITIAL_DELAY = 1.0  # seconds slept before the first retry
_DEFAULT_RETRY_BACKOFF_FACTOR = 2.0  # multiplier applied per retry attempt
_DEFAULT_RETRY_MAX_DELAY = 30.0  # cap on any single backoff sleep, in seconds
# Per-call precheck overrides consumed via **kwargs; anything else is forwarded
# into the request body for forward compatibility.
_PRECHECK_CONFIG_KEYS = frozenset(
    {
        "timeout",
        "max_retries",
        "retry_initial_delay",
        "retry_backoff_factor",
        "retry_max_delay",
        "jitter",
        "scope",
        "user_id",
        "corr_id",
        "tags",
        "payload",
    }
)


def _is_retryable_status(status_code: int) -> bool:
return status_code >= 500 or status_code == 429


def _compute_retry_delay(
attempt: int,
initial: float,
factor: float,
max_delay: float,
jitter: bool,
) -> float:
"""Exponential backoff with optional jitter, capped at ``max_delay``."""
delay = min(initial * (factor**attempt), max_delay)
if jitter:
delay *= random.uniform(0.5, 1.5)
return min(delay, max_delay)


class GovernsAIError(Exception):
"""Base error for GovernsAI SDK"""
Expand Down Expand Up @@ -65,13 +105,50 @@ def __repr__(self):
return f"<GovernsAIClient(base_url='{self.base_url}', org_id='{self.org_id}')>"

def _get_payload(
self, content: str, tool: str, org_id: Optional[str]
self,
content: str,
tool: str,
org_id: Optional[str],
*,
scope: str = "net.external",
user_id: Optional[str] = None,
corr_id: Optional[str] = None,
tags: Optional[List[str]] = None,
extra_payload: Optional[Dict[str, Any]] = None,
extras: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
return {
payload: Dict[str, Any] = {
"tool": tool,
"raw_text": content,
"org_id": org_id or self.org_id,
"scope": "net.external",
"scope": scope,
}
if user_id is not None:
payload["user_id"] = user_id
if corr_id is not None:
payload["corr_id"] = corr_id
if tags is not None:
payload["tags"] = tags
if extra_payload is not None:
payload["payload"] = extra_payload
if extras:
# Unknown kwargs pass through so the SDK tolerates server schema growth.
payload.update(extras)
return payload

def _resolve_retry_config(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Pop retry/config kwargs, falling back to client defaults.

    Consumed keys are removed from *kwargs* in place; anything left over is
    the caller's to forward into the request body.
    """
    defaults = (
        ("timeout", self.timeout),
        ("max_retries", self.max_retries),
        ("retry_initial_delay", _DEFAULT_RETRY_INITIAL_DELAY),
        ("retry_backoff_factor", _DEFAULT_RETRY_BACKOFF_FACTOR),
        ("retry_max_delay", _DEFAULT_RETRY_MAX_DELAY),
        ("jitter", False),
    )
    return {key: kwargs.pop(key, fallback) for key, fallback in defaults}

def _parse_response(
Expand Down Expand Up @@ -109,29 +186,73 @@ def precheck(
content: str,
tool: str,
org_id: Optional[str] = None,
**kwargs: Any,
) -> PrecheckResult:
"""Check a content/tool request for governance compliance.

Args:
content: Raw user-facing text to evaluate.
tool: Tool identifier (e.g. ``"model.chat"``).
org_id: Organization owning the request. Falls back to the
client-level ``org_id``.

Keyword Args:
timeout: Per-call timeout override (seconds).
max_retries: Override the client-level retry ceiling.
retry_initial_delay: First-attempt backoff delay in seconds
(default 1.0).
retry_backoff_factor: Multiplier applied per retry (default 2.0).
retry_max_delay: Upper bound on a single backoff sleep (default 30.0).
jitter: Multiply backoff by ``uniform(0.5, 1.5)`` when True.
scope: Override the ``scope`` field (default ``"net.external"``).
user_id, corr_id, tags, payload: Optional request body fields.
Any other kwargs are forwarded into the request body verbatim.

Returns:
:class:`PrecheckResult` with decision, redacted_content, reasons,
and client-measured ``latency_ms``.

Raises:
PrecheckError: On non-retryable 4xx or exhausted retries.
"""
Check a request for governance compliance.
"""
payload = self._get_payload(content, tool, org_id)
retry = self._resolve_retry_config(kwargs)
payload = self._get_payload(
content,
tool,
org_id,
scope=kwargs.pop("scope", "net.external"),
user_id=kwargs.pop("user_id", None),
corr_id=kwargs.pop("corr_id", None),
tags=kwargs.pop("tags", None),
extra_payload=kwargs.pop("payload", None),
extras={k: v for k, v in kwargs.items() if k not in _PRECHECK_CONFIG_KEYS},
)
start_time = time.time()

last_error_msg = "Unknown error"
for attempt in range(self.max_retries + 1):
for attempt in range(retry["max_retries"] + 1):
try:
with httpx.Client(timeout=self.timeout) as client:
with httpx.Client(timeout=retry["timeout"]) as client:
response = client.post(
f"{self.base_url}/api/v1/precheck",
json=payload,
headers=self.headers,
)

if response.status_code >= 500 or response.status_code == 429:
if _is_retryable_status(response.status_code):
last_error_msg = (
f"HTTP {response.status_code} {response.reason_phrase}"
)
if attempt < self.max_retries:
time.sleep(2**attempt)
if attempt < retry["max_retries"]:
time.sleep(
_compute_retry_delay(
attempt,
retry["retry_initial_delay"],
retry["retry_backoff_factor"],
retry["retry_max_delay"],
retry["jitter"],
)
)
continue
else:
break
Expand All @@ -140,8 +261,16 @@ def precheck(
return self._parse_response(response, latency_ms)
except (httpx.RequestError, httpx.TimeoutException) as e:
last_error_msg = str(e)
if attempt < self.max_retries:
time.sleep(2**attempt)
if attempt < retry["max_retries"]:
time.sleep(
_compute_retry_delay(
attempt,
retry["retry_initial_delay"],
retry["retry_backoff_factor"],
retry["retry_max_delay"],
retry["jitter"],
)
)
continue

raise PrecheckError(f"Max retries exceeded: {last_error_msg}")
Expand All @@ -151,29 +280,47 @@ async def async_precheck(
content: str,
tool: str,
org_id: Optional[str] = None,
**kwargs: Any,
) -> PrecheckResult:
"""
Async version of precheck.
"""
payload = self._get_payload(content, tool, org_id)
"""Async counterpart of :meth:`precheck` accepting the same kwargs."""
retry = self._resolve_retry_config(kwargs)
payload = self._get_payload(
content,
tool,
org_id,
scope=kwargs.pop("scope", "net.external"),
user_id=kwargs.pop("user_id", None),
corr_id=kwargs.pop("corr_id", None),
tags=kwargs.pop("tags", None),
extra_payload=kwargs.pop("payload", None),
extras={k: v for k, v in kwargs.items() if k not in _PRECHECK_CONFIG_KEYS},
)
start_time = time.time()

last_error_msg = "Unknown error"
for attempt in range(self.max_retries + 1):
for attempt in range(retry["max_retries"] + 1):
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
async with httpx.AsyncClient(timeout=retry["timeout"]) as client:
response = await client.post(
f"{self.base_url}/api/v1/precheck",
json=payload,
headers=self.headers,
)

if response.status_code >= 500 or response.status_code == 429:
if _is_retryable_status(response.status_code):
last_error_msg = (
f"HTTP {response.status_code} {response.reason_phrase}"
)
if attempt < self.max_retries:
await asyncio.sleep(2**attempt)
if attempt < retry["max_retries"]:
await asyncio.sleep(
_compute_retry_delay(
attempt,
retry["retry_initial_delay"],
retry["retry_backoff_factor"],
retry["retry_max_delay"],
retry["jitter"],
)
)
continue
else:
break
Expand All @@ -182,8 +329,16 @@ async def async_precheck(
return self._parse_response(response, latency_ms)
except (httpx.RequestError, httpx.TimeoutException) as e:
last_error_msg = str(e)
if attempt < self.max_retries:
await asyncio.sleep(2**attempt)
if attempt < retry["max_retries"]:
await asyncio.sleep(
_compute_retry_delay(
attempt,
retry["retry_initial_delay"],
retry["retry_backoff_factor"],
retry["retry_max_delay"],
retry["jitter"],
)
)
continue

raise PrecheckError(f"Max retries exceeded: {last_error_msg}")
Expand Down
Loading
Loading