Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions src/governs_ai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,32 +192,69 @@ async def async_precheck(
# 1.4c — record_usage()
# ------------------------------------------------------------------

_USAGE_KWARG_MAP = {
"output_tokens": "outputTokens",
"tokens_out": "outputTokens",
"provider": "provider",
"cost": "cost",
"tool": "toolId",
"tool_id": "toolId",
"correlation_id": "correlationId",
"metadata": "metadata",
}

def _build_usage_payload(
self,
org_id: str,
user_id: str,
tokens: int,
model: str,
extras: Dict[str, Any],
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"orgId": org_id or self.org_id,
"userId": user_id,
"inputTokens": tokens,
"outputTokens": 0,
"model": model,
"provider": "openai",
}
for key, value in extras.items():
if value is None:
continue
payload_key = self._USAGE_KWARG_MAP.get(key, key)
payload[payload_key] = value
return payload

def record_usage(
self,
org_id: str,
user_id: str,
tokens: int,
model: str,
*,
provider: str = "openai",
**kwargs: Any,
) -> None:
"""Record token usage for a model request.

Args:
org_id: Organization ID (falls back to ``client.org_id``).
user_id: End-user identifier.
tokens: Input token count for the request.
model: Model identifier (e.g., ``"gpt-4o-mini"``).
**kwargs: Optional extras forwarded to the platform API. Recognised
keys: ``output_tokens``, ``provider``, ``cost``, ``tool_id``,
``correlation_id``, ``metadata``. Unknown keys are passed
through unchanged.

Example::

client.record_usage(
org_id="org-1", user_id="user-123",
tokens=180, model="gpt-4o-mini",
output_tokens=42, cost=0.0012,
)
"""
payload: Dict[str, Any] = {
"orgId": org_id or self.org_id,
"userId": user_id,
"inputTokens": tokens,
"outputTokens": 0,
"model": model,
"provider": provider,
}
payload = self._build_usage_payload(org_id, user_id, tokens, model, kwargs)
with httpx.Client(timeout=self.timeout) as http:
resp = http.post(
f"{self.base_url}/api/v1/usage",
Expand All @@ -226,7 +263,8 @@ def record_usage(
)
if resp.status_code >= 400:
raise GovernsAIError(
f"record_usage failed with HTTP {resp.status_code}: {resp.text}",
f"record_usage failed with HTTP {resp.status_code}: {resp.text} — "
f"verify org_id/user_id and that the API key has usage write scope",
status_code=resp.status_code,
)

Expand All @@ -236,18 +274,13 @@ async def async_record_usage(
user_id: str,
tokens: int,
model: str,
*,
provider: str = "openai",
**kwargs: Any,
) -> None:
"""Async variant of :meth:`record_usage`."""
payload: Dict[str, Any] = {
"orgId": org_id or self.org_id,
"userId": user_id,
"inputTokens": tokens,
"outputTokens": 0,
"model": model,
"provider": provider,
}
"""Async variant of :meth:`record_usage`.

Accepts the same arguments and kwargs as the sync form.
"""
payload = self._build_usage_payload(org_id, user_id, tokens, model, kwargs)
async with httpx.AsyncClient(timeout=self.timeout) as http:
resp = await http.post(
f"{self.base_url}/api/v1/usage",
Expand All @@ -256,7 +289,8 @@ async def async_record_usage(
)
if resp.status_code >= 400:
raise GovernsAIError(
f"record_usage failed with HTTP {resp.status_code}: {resp.text}",
f"record_usage failed with HTTP {resp.status_code}: {resp.text} — "
f"verify org_id/user_id and that the API key has usage write scope",
status_code=resp.status_code,
)

Expand Down
41 changes: 41 additions & 0 deletions tests/test_record_usage_budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,47 @@ async def test_async_record_usage_sends_correct_payload(client):
assert body["inputTokens"] == 50


@respx.mock
def test_record_usage_succeeds_on_204_no_content(client):
"""Platform API returns 204 No Content — SDK must not raise or try to parse body."""
respx.post(f"{BASE}/api/v1/usage").mock(return_value=httpx.Response(204))
client.record_usage(org_id="org-1", user_id="user-123", tokens=10, model="gpt-4o")


@respx.mock
async def test_async_record_usage_succeeds_on_204_no_content(client):
respx.post(f"{BASE}/api/v1/usage").mock(return_value=httpx.Response(204))
await client.async_record_usage(
org_id="org-1", user_id="user-123", tokens=10, model="gpt-4o"
)


@respx.mock
def test_record_usage_forwards_kwargs_to_payload(client):
"""Recognised kwargs are mapped to camelCase platform fields; unknown kwargs pass through."""
route = respx.post(f"{BASE}/api/v1/usage").mock(return_value=httpx.Response(204))
client.record_usage(
org_id="org-1",
user_id="user-123",
tokens=100,
model="gpt-4o",
output_tokens=50,
cost=0.0012,
provider="anthropic",
tool_id="web_search",
correlation_id="req-abc",
metadata={"session": "s1"},
)
body = json.loads(route.calls[0].request.content)
assert body["inputTokens"] == 100
assert body["outputTokens"] == 50
assert body["cost"] == 0.0012
assert body["provider"] == "anthropic"
assert body["toolId"] == "web_search"
assert body["correlationId"] == "req-abc"
assert body["metadata"] == {"session": "s1"}


@respx.mock
def test_budget_check_allowed(client):
respx.get(f"{BASE}/api/v1/budget/context").mock(
Expand Down
Loading