Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions governs_ai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
PrecheckRequest,
PrecheckResponse,
BudgetContext,
BudgetResult,
UsageRecord,
ConfirmationRequest,
ConfirmationResponse,
Expand Down Expand Up @@ -201,17 +202,59 @@ async def get_budget_context(self, user_id: str) -> BudgetContext:
"""
return await self.budget.get_budget_context(user_id)

async def record_usage(self, usage: Union[UsageRecord, Dict[str, Any]]) -> None:
async def budget_check(
self,
org_id: str,
user_id: str,
estimated_tokens: int,
) -> BudgetResult:
"""Check whether a user is within token budget.

Returns a :class:`BudgetResult` with ``allowed``, ``remaining_tokens``,
``limit``, and ``warning_threshold_hit`` (True when < 10% remaining).

Example::

result = await client.budget_check(
org_id="org-123", user_id="user-456", estimated_tokens=1000
)
if not result.allowed:
raise RuntimeError("Budget exceeded")
"""
Record usage for a user.
response = await self.budget.http_client.get(
"/api/v1/budget/context",
params={"orgId": org_id, "userId": user_id, "estimatedTokens": estimated_tokens},
)
return BudgetResult.from_dict(response.data)

Args:
usage: Usage record or dictionary
async def async_budget_check(
self,
org_id: str,
user_id: str,
estimated_tokens: int,
) -> BudgetResult:
"""Async variant of :meth:`budget_check` (identical — both are coroutines)."""
return await self.budget_check(org_id, user_id, estimated_tokens)

async def record_usage(self, usage: Union[UsageRecord, Dict[str, Any]]) -> None:
"""Record token and cost usage for a model request.

Example::

await client.record_usage(UsageRecord(
user_id="user-123", org_id="org-456", provider="openai",
model="gpt-4o-mini", input_tokens=100, output_tokens=80,
cost=0.0017, cost_type="external",
))
"""
if isinstance(usage, dict):
usage = UsageRecord.from_dict(usage)
await self.budget.record_usage(usage)

async def async_record_usage(self, usage: Union[UsageRecord, Dict[str, Any]]) -> None:
"""Async variant of :meth:`record_usage` (identical — both are coroutines)."""
return await self.record_usage(usage)

async def create_confirmation(
self,
request_type: str,
Expand Down
1 change: 1 addition & 0 deletions governs_ai/exceptions/precheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Precheck-specific exceptions.
"""

from typing import Optional, Dict, Any
from .base import GovernsAIError


Expand Down
3 changes: 2 additions & 1 deletion governs_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from .precheck import PrecheckRequest, PrecheckResponse, Decision
from .budget import BudgetContext, UsageRecord, BudgetStatus
from .budget import BudgetContext, UsageRecord, BudgetStatus, BudgetResult
from .confirmation import ConfirmationRequest, ConfirmationResponse
from .health import HealthStatus
from .context import (
Expand Down Expand Up @@ -46,6 +46,7 @@
"BudgetContext",
"UsageRecord",
"BudgetStatus",
"BudgetResult",
"ConfirmationRequest",
"ConfirmationResponse",
"HealthStatus",
Expand Down
49 changes: 49 additions & 0 deletions governs_ai/models/budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,55 @@ def to_dict(self) -> Dict[str, Any]:
}


@dataclass
class BudgetResult:
"""Result of a budget_check() call.

Example::

result = await client.budget_check(
org_id="org-123", user_id="user-456", estimated_tokens=1000
)
if not result.allowed:
raise RuntimeError("Budget exceeded")
if result.warning_threshold_hit:
logger.warning("Less than 10% budget remaining")
"""
allowed: bool
remaining_tokens: int
limit: int
warning_threshold_hit: bool = False
reason: Optional[str] = None

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BudgetResult":
"""Create from dict; auto-computes warning_threshold_hit when < 10% budget remains."""
limit = int(data.get("limit", data.get("monthly_limit", 0)))
remaining = int(
data.get("remaining_tokens",
data.get("remaining_budget",
data.get("remainingBudget", 0)))
)
warning = limit > 0 and (remaining / limit) < 0.10
return cls(
allowed=data.get("allowed", remaining > 0),
remaining_tokens=remaining,
limit=limit,
warning_threshold_hit=data.get("warningThresholdHit", warning),
reason=data.get("reason"),
)

def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"allowed": self.allowed,
"remainingTokens": self.remaining_tokens,
"limit": self.limit,
"warningThresholdHit": self.warning_threshold_hit,
"reason": self.reason,
}


@dataclass
class BudgetStatus:
"""Budget status for a user."""
Expand Down
196 changes: 196 additions & 0 deletions tests/test_budget_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024 GovernsAI. All rights reserved.
"""Unit tests for record_usage() and budget_check()."""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from governs_ai.models.budget import BudgetResult, UsageRecord
from governs_ai.client import GovernsAIClient, GovernsAIConfig


def _make_client() -> GovernsAIClient:
with patch("governs_ai.client.HTTPClient"):
client = GovernsAIClient(
api_key="test-key", org_id="org-123", base_url="https://api.test"
)
return client


class TestBudgetResult:
"""Unit tests for BudgetResult model."""

def test_from_dict_allowed(self):
result = BudgetResult.from_dict({
"allowed": True,
"remaining_tokens": 9000,
"limit": 10000,
})
assert result.allowed is True
assert result.remaining_tokens == 9000
assert result.limit == 10000
assert result.warning_threshold_hit is False

def test_from_dict_denied_when_budget_exceeded(self):
result = BudgetResult.from_dict({
"allowed": False,
"remaining_tokens": 0,
"limit": 10000,
"reason": "over_budget",
})
assert result.allowed is False
assert result.reason == "over_budget"

def test_warning_threshold_hit_when_below_10_percent(self):
result = BudgetResult.from_dict({
"allowed": True,
"remaining_tokens": 500,
"limit": 10000,
})
assert result.warning_threshold_hit is True

def test_warning_threshold_not_hit_at_exactly_10_percent(self):
result = BudgetResult.from_dict({
"allowed": True,
"remaining_tokens": 1000,
"limit": 10000,
})
# 1000/10000 == 0.10, not < 0.10
assert result.warning_threshold_hit is False

def test_from_dict_camelcase_keys(self):
result = BudgetResult.from_dict({
"allowed": True,
"remainingBudget": 8000,
"limit": 10000,
})
assert result.remaining_tokens == 8000


class TestBudgetCheckMethod:
"""Unit tests for GovernsAIClient.budget_check()."""

@pytest.mark.asyncio
async def test_budget_check_allowed(self):
client = _make_client()
mock_response = MagicMock()
mock_response.data = {
"allowed": True,
"remaining_tokens": 9000,
"limit": 10000,
}
client.budget.http_client.get = AsyncMock(return_value=mock_response)

result = await client.budget_check(
org_id="org-123", user_id="user-456", estimated_tokens=500
)

assert isinstance(result, BudgetResult)
assert result.allowed is True
assert result.warning_threshold_hit is False
client.budget.http_client.get.assert_called_once_with(
"/api/v1/budget/context",
params={"orgId": "org-123", "userId": "user-456", "estimatedTokens": 500},
)

@pytest.mark.asyncio
async def test_budget_check_denied_when_over_budget(self):
client = _make_client()
mock_response = MagicMock()
mock_response.data = {
"allowed": False,
"remaining_tokens": 0,
"limit": 10000,
"reason": "over_budget",
}
client.budget.http_client.get = AsyncMock(return_value=mock_response)

result = await client.budget_check(
org_id="org-123", user_id="user-456", estimated_tokens=1000
)
assert result.allowed is False
assert result.reason == "over_budget"

@pytest.mark.asyncio
async def test_budget_check_warning_threshold_hit(self):
"""warning_threshold_hit=True when remaining < 10% of limit."""
client = _make_client()
mock_response = MagicMock()
mock_response.data = {
"allowed": True,
"remaining_tokens": 500,
"limit": 10000,
}
client.budget.http_client.get = AsyncMock(return_value=mock_response)

result = await client.budget_check(
org_id="org-123", user_id="user-456", estimated_tokens=100
)
assert result.allowed is True
assert result.warning_threshold_hit is True

@pytest.mark.asyncio
async def test_async_budget_check_is_alias(self):
client = _make_client()
mock_response = MagicMock()
mock_response.data = {"allowed": True, "remaining_tokens": 5000, "limit": 10000}
client.budget.http_client.get = AsyncMock(return_value=mock_response)

result = await client.async_budget_check(
org_id="org-123", user_id="user-456", estimated_tokens=200
)
assert isinstance(result, BudgetResult)


class TestRecordUsageMethod:
"""Unit tests for GovernsAIClient.record_usage()."""

@pytest.mark.asyncio
async def test_record_usage_correct_payload(self):
client = _make_client()
client.budget.record_usage = AsyncMock()

record = UsageRecord(
user_id="user-123",
org_id="org-456",
provider="openai",
model="gpt-4o-mini",
input_tokens=100,
output_tokens=80,
cost=0.0017,
cost_type="external",
)
await client.record_usage(record)
client.budget.record_usage.assert_called_once_with(record)

@pytest.mark.asyncio
async def test_record_usage_dict_converted_to_usage_record(self):
client = _make_client()
client.budget.record_usage = AsyncMock()

await client.record_usage({
"userId": "user-123",
"orgId": "org-456",
"provider": "openai",
"model": "gpt-4",
"inputTokens": 50,
"outputTokens": 30,
"cost": 0.001,
"costType": "external",
})
call_arg = client.budget.record_usage.call_args[0][0]
assert isinstance(call_arg, UsageRecord)
assert call_arg.user_id == "user-123"
assert call_arg.model == "gpt-4"

@pytest.mark.asyncio
async def test_async_record_usage_is_alias(self):
client = _make_client()
client.budget.record_usage = AsyncMock()

await client.async_record_usage(
UsageRecord(
user_id="u", org_id="o", provider="openai", model="gpt-4",
input_tokens=10, output_tokens=5, cost=0.0001, cost_type="external",
)
)
assert client.budget.record_usage.call_count == 1
Loading