From a2ca25009eb2e00c76aecddf3fbbc598d5b76fce Mon Sep 17 00:00:00 2001 From: Shaishav Pidadi Date: Fri, 17 Apr 2026 17:12:44 -0400 Subject: [PATCH] feat(python-sdk): add record_usage() and budget_check() (GOV-38) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - client.record_usage(org_id, user_id, tokens, model) → None - client.async_record_usage() — async variant - client.budget_check(org_id, user_id, estimated_tokens) → BudgetResult - client.async_budget_check() — async variant - BudgetResult: allowed, remaining_tokens, limit, warning_threshold_hit, reason - 5 unit tests: correct payload, allowed/denied, warning threshold (< 10%) Refs: GOV-38 --- src/governs_ai/__init__.py | 10 +- src/governs_ai/client.py | 157 +++++++++++++++++++++++++++++- src/governs_ai/types.py | 20 ++++ tests/test_record_usage_budget.py | 82 ++++++++++++++++ 4 files changed, 266 insertions(+), 3 deletions(-) create mode 100644 tests/test_record_usage_budget.py diff --git a/src/governs_ai/__init__.py b/src/governs_ai/__init__.py index 11629b2..a5299e6 100644 --- a/src/governs_ai/__init__.py +++ b/src/governs_ai/__init__.py @@ -1,4 +1,10 @@ from .client import GovernsAIClient, GovernsAIError, PrecheckError -from .types import PrecheckResult +from .types import BudgetResult, PrecheckResult -__all__ = ["GovernsAIClient", "GovernsAIError", "PrecheckError", "PrecheckResult"] +__all__ = [ + "GovernsAIClient", + "GovernsAIError", + "PrecheckError", + "PrecheckResult", + "BudgetResult", +] diff --git a/src/governs_ai/client.py b/src/governs_ai/client.py index f216a61..1831e28 100644 --- a/src/governs_ai/client.py +++ b/src/governs_ai/client.py @@ -4,7 +4,7 @@ import httpx -from .types import PrecheckResult +from .types import BudgetResult, PrecheckResult class GovernsAIError(Exception): @@ -180,3 +180,158 @@ async def async_precheck( continue raise PrecheckError(f"Max retries exceeded: {last_error_msg}") + + # ------------------------------------------------------------------ + # 1.4c — record_usage() + # ------------------------------------------------------------------ + + def record_usage( + self, + org_id: str, + user_id: str, + tokens: int, + model: str, + *, + provider: str = "openai", + ) -> None: + """Record token usage for a model request. + + Example:: + + client.record_usage( + org_id="org-1", user_id="user-123", + tokens=180, model="gpt-4o-mini", + ) + """ + payload: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "inputTokens": tokens, + "outputTokens": 0, + "model": model, + "provider": provider, + } + with httpx.Client(timeout=self.timeout) as http: + resp = http.post( + f"{self.base_url}/api/v1/usage", + json=payload, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"record_usage failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + + async def async_record_usage( + self, + org_id: str, + user_id: str, + tokens: int, + model: str, + *, + provider: str = "openai", + ) -> None: + """Async variant of :meth:`record_usage`.""" + payload: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "inputTokens": tokens, + "outputTokens": 0, + "model": model, + "provider": provider, + } + async with httpx.AsyncClient(timeout=self.timeout) as http: + resp = await http.post( + f"{self.base_url}/api/v1/usage", + json=payload, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"record_usage failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + + # ------------------------------------------------------------------ + # 1.4c — budget_check() + # ------------------------------------------------------------------ + + def budget_check( + self, + org_id: str, + user_id: str, + estimated_tokens: int = 0, + ) -> BudgetResult: + """Check whether the user/org is within budget. + + Example:: + + budget = client.budget_check(org_id="org-1", user_id="u1", estimated_tokens=500) + if not budget.allowed: + raise RuntimeError("Budget exceeded") + """ + params: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "estimatedTokens": estimated_tokens, + } + with httpx.Client(timeout=self.timeout) as http: + resp = http.get( + f"{self.base_url}/api/v1/budget/context", + params=params, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"budget_check failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + data = resp.json() + limit = data.get("limit", data.get("monthly_limit", 0)) + remaining = data.get("remaining_tokens", data.get("remaining", limit)) + allowed = data.get("allowed", remaining > 0) + warning_threshold_hit = limit > 0 and (remaining / limit) < 0.10 + return BudgetResult( + allowed=allowed, + remaining_tokens=int(remaining), + limit=int(limit), + warning_threshold_hit=warning_threshold_hit, + reason=data.get("reason", ""), + ) + + async def async_budget_check( + self, + org_id: str, + user_id: str, + estimated_tokens: int = 0, + ) -> BudgetResult: + """Async variant of :meth:`budget_check`.""" + params: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "estimatedTokens": estimated_tokens, + } + async with httpx.AsyncClient(timeout=self.timeout) as http: + resp = await http.get( + f"{self.base_url}/api/v1/budget/context", + params=params, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"budget_check failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + data = resp.json() + limit = data.get("limit", data.get("monthly_limit", 0)) + remaining = data.get("remaining_tokens", data.get("remaining", limit)) + allowed = data.get("allowed", remaining > 0) + warning_threshold_hit = limit > 0 and (remaining / limit) < 0.10 + return BudgetResult( + allowed=allowed, + remaining_tokens=int(remaining), + limit=int(limit), + warning_threshold_hit=warning_threshold_hit, + reason=data.get("reason", ""), + ) diff --git a/src/governs_ai/types.py b/src/governs_ai/types.py index 366c918..41113e3 100644 --- a/src/governs_ai/types.py +++ b/src/governs_ai/types.py @@ -8,3 +8,23 @@ class PrecheckResult: redacted_content: Optional[str] = None reasons: List[str] = field(default_factory=list) latency_ms: float = 0.0 + + +@dataclass +class BudgetResult: + """Result of a budget_check call. + + Example:: + + budget = client.budget_check(org_id="org-1", user_id="user-1", estimated_tokens=500) + if not budget.allowed: + raise RuntimeError("Budget exceeded") + if budget.warning_threshold_hit: + logger.warning("Less than 10% budget remaining") + """ + + allowed: bool + remaining_tokens: int + limit: int + warning_threshold_hit: bool + reason: str = "" diff --git a/tests/test_record_usage_budget.py b/tests/test_record_usage_budget.py new file mode 100644 index 0000000..c269ceb --- /dev/null +++ b/tests/test_record_usage_budget.py @@ -0,0 +1,82 @@ +"""Unit tests for record_usage() and budget_check().""" + +import json +import pytest +import respx +import httpx + +from governs_ai import GovernsAIClient, BudgetResult + +BASE = "https://api.governs.ai" + + +@pytest.fixture +def client(): + return GovernsAIClient(api_key="test-key", org_id="org-test") + + +@respx.mock +def test_record_usage_sends_correct_payload(client): + route = respx.post(f"{BASE}/api/v1/usage").mock( + return_value=httpx.Response(200, json={"accepted": True}) + ) + client.record_usage(org_id="org-1", user_id="user-123", tokens=100, model="gpt-4o") + body = json.loads(route.calls[0].request.content) + assert body["orgId"] == "org-1" + assert body["userId"] == "user-123" + assert body["inputTokens"] == 100 + assert body["model"] == "gpt-4o" + + +@respx.mock +async def test_async_record_usage_sends_correct_payload(client): + route = respx.post(f"{BASE}/api/v1/usage").mock( + return_value=httpx.Response(200, json={"accepted": True}) + ) + await client.async_record_usage( + org_id="org-1", user_id="user-123", tokens=50, model="gpt-4o-mini" + ) + body = json.loads(route.calls[0].request.content) + assert body["userId"] == "user-123" + assert body["inputTokens"] == 50 + + +@respx.mock +def test_budget_check_allowed(client): + respx.get(f"{BASE}/api/v1/budget/context").mock( + return_value=httpx.Response( + 200, + json={"allowed": True, "remaining_tokens": 9000, "limit": 10000, "reason": ""}, + ) + ) + result = client.budget_check(org_id="org-1", user_id="user-1") + assert isinstance(result, BudgetResult) + assert result.allowed is True + assert result.warning_threshold_hit is False + + +@respx.mock +def test_budget_check_denied_when_over_budget(client): + respx.get(f"{BASE}/api/v1/budget/context").mock( + return_value=httpx.Response( + 200, + json={"allowed": False, "remaining_tokens": 0, "limit": 10000, "reason": "over_budget"}, + ) + ) + result = client.budget_check(org_id="org-1", user_id="user-1") + assert result.allowed is False + assert result.reason == "over_budget" + + +@respx.mock +def test_budget_check_warning_threshold_hit(client): + """warning_threshold_hit=True when remaining < 10% of limit.""" + respx.get(f"{BASE}/api/v1/budget/context").mock( + return_value=httpx.Response( + 200, + json={"allowed": True, "remaining_tokens": 500, "limit": 10000, "reason": ""}, + ) + ) + result = client.budget_check(org_id="org-1", user_id="user-1") + assert result.allowed is True + assert result.warning_threshold_hit is True