From 7f894fe5ed92cbe4630ef73eba48a012827d24d9 Mon Sep 17 00:00:00 2001 From: Shaishav Pidadi Date: Fri, 17 Apr 2026 17:07:18 -0400 Subject: [PATCH] =?UTF-8?q?feat(python-sdk):=20add=20budget=5Fcheck()=20an?= =?UTF-8?q?d=20BudgetResult=20=E2=80=94=20GOV-38?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add BudgetResult model with allowed, remaining_tokens, limit, warning_threshold_hit (auto-set when < 10% budget remains) - Add budget_check(org_id, user_id, estimated_tokens) and async variant on GovernsAIClient - Add async_record_usage() alias - Fix missing Optional/Dict/Any imports in exceptions/precheck.py - 12 unit tests covering allowed/denied/warning/async variants Refs: GOV-38 --- governs_ai/client.py | 51 +++++++- governs_ai/exceptions/precheck.py | 1 + governs_ai/models/__init__.py | 3 +- governs_ai/models/budget.py | 49 ++++++++ tests/test_budget_check.py | 196 ++++++++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 5 deletions(-) create mode 100644 tests/test_budget_check.py diff --git a/governs_ai/client.py b/governs_ai/client.py index d00ba71..d70cff2 100644 --- a/governs_ai/client.py +++ b/governs_ai/client.py @@ -14,6 +14,7 @@ PrecheckRequest, PrecheckResponse, BudgetContext, + BudgetResult, UsageRecord, ConfirmationRequest, ConfirmationResponse, @@ -201,17 +202,59 @@ async def get_budget_context(self, user_id: str) -> BudgetContext: """ return await self.budget.get_budget_context(user_id) - async def record_usage(self, usage: Union[UsageRecord, Dict[str, Any]]) -> None: + async def budget_check( + self, + org_id: str, + user_id: str, + estimated_tokens: int, + ) -> BudgetResult: + """Check whether a user is within token budget. + + Returns a :class:`BudgetResult` with ``allowed``, ``remaining_tokens``, + ``limit``, and ``warning_threshold_hit`` (True when < 10% remaining). + + Example:: + + result = await client.budget_check( + org_id="org-123", user_id="user-456", estimated_tokens=1000 + ) + if not result.allowed: + raise RuntimeError("Budget exceeded") """ - Record usage for a user. + response = await self.budget.http_client.get( + "/api/v1/budget/context", + params={"orgId": org_id, "userId": user_id, "estimatedTokens": estimated_tokens}, + ) + return BudgetResult.from_dict(response.data) - Args: - usage: Usage record or dictionary + async def async_budget_check( + self, + org_id: str, + user_id: str, + estimated_tokens: int, + ) -> BudgetResult: + """Async variant of :meth:`budget_check` (identical — both are coroutines).""" + return await self.budget_check(org_id, user_id, estimated_tokens) + + async def record_usage(self, usage: Union[UsageRecord, Dict[str, Any]]) -> None: + """Record token and cost usage for a model request. + + Example:: + + await client.record_usage(UsageRecord( + user_id="user-123", org_id="org-456", provider="openai", + model="gpt-4o-mini", input_tokens=100, output_tokens=80, + cost=0.0017, cost_type="external", + )) """ if isinstance(usage, dict): usage = UsageRecord.from_dict(usage) await self.budget.record_usage(usage) + async def async_record_usage(self, usage: Union[UsageRecord, Dict[str, Any]]) -> None: + """Async variant of :meth:`record_usage` (identical — both are coroutines).""" + return await self.record_usage(usage) + async def create_confirmation( self, request_type: str, diff --git a/governs_ai/exceptions/precheck.py b/governs_ai/exceptions/precheck.py index 37e4843..1a7df0f 100644 --- a/governs_ai/exceptions/precheck.py +++ b/governs_ai/exceptions/precheck.py @@ -4,6 +4,7 @@ Precheck-specific exceptions. """ +from typing import Optional, Dict, Any from .base import GovernsAIError diff --git a/governs_ai/models/__init__.py b/governs_ai/models/__init__.py index 174ba48..85ef94c 100644 --- a/governs_ai/models/__init__.py +++ b/governs_ai/models/__init__.py @@ -5,7 +5,7 @@ """ from .precheck import PrecheckRequest, PrecheckResponse, Decision -from .budget import BudgetContext, UsageRecord, BudgetStatus +from .budget import BudgetContext, UsageRecord, BudgetStatus, BudgetResult from .confirmation import ConfirmationRequest, ConfirmationResponse from .health import HealthStatus from .context import ( @@ -46,6 +46,7 @@ "BudgetContext", "UsageRecord", "BudgetStatus", + "BudgetResult", "ConfirmationRequest", "ConfirmationResponse", "HealthStatus", diff --git a/governs_ai/models/budget.py b/governs_ai/models/budget.py index ac0e51f..209b901 100644 --- a/governs_ai/models/budget.py +++ b/governs_ai/models/budget.py @@ -43,6 +43,55 @@ def to_dict(self) -> Dict[str, Any]: } +@dataclass +class BudgetResult: + """Result of a budget_check() call. + + Example:: + + result = await client.budget_check( + org_id="org-123", user_id="user-456", estimated_tokens=1000 + ) + if not result.allowed: + raise RuntimeError("Budget exceeded") + if result.warning_threshold_hit: + logger.warning("Less than 10% budget remaining") + """ + allowed: bool + remaining_tokens: int + limit: int + warning_threshold_hit: bool = False + reason: Optional[str] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BudgetResult": + """Create from dict; auto-computes warning_threshold_hit when < 10% budget remains.""" + limit = int(data.get("limit", data.get("monthly_limit", 0))) + remaining = int( + data.get("remaining_tokens", + data.get("remaining_budget", + data.get("remainingBudget", 0))) + ) + warning = limit > 0 and (remaining / limit) < 0.10 + return cls( + allowed=data.get("allowed", remaining > 0), + remaining_tokens=remaining, + limit=limit, + warning_threshold_hit=data.get("warningThresholdHit", warning), + reason=data.get("reason"), + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "allowed": self.allowed, + "remainingTokens": self.remaining_tokens, + "limit": self.limit, + "warningThresholdHit": self.warning_threshold_hit, + "reason": self.reason, + } + + @dataclass class BudgetStatus: """Budget status for a user.""" diff --git a/tests/test_budget_check.py b/tests/test_budget_check.py new file mode 100644 index 0000000..3b8099c --- /dev/null +++ b/tests/test_budget_check.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024 GovernsAI. All rights reserved. +"""Unit tests for record_usage() and budget_check().""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from governs_ai.models.budget import BudgetResult, UsageRecord +from governs_ai.client import GovernsAIClient, GovernsAIConfig + + +def _make_client() -> GovernsAIClient: + with patch("governs_ai.client.HTTPClient"): + client = GovernsAIClient( + api_key="test-key", org_id="org-123", base_url="https://api.test" + ) + return client + + +class TestBudgetResult: + """Unit tests for BudgetResult model.""" + + def test_from_dict_allowed(self): + result = BudgetResult.from_dict({ + "allowed": True, + "remaining_tokens": 9000, + "limit": 10000, + }) + assert result.allowed is True + assert result.remaining_tokens == 9000 + assert result.limit == 10000 + assert result.warning_threshold_hit is False + + def test_from_dict_denied_when_budget_exceeded(self): + result = BudgetResult.from_dict({ + "allowed": False, + "remaining_tokens": 0, + "limit": 10000, + "reason": "over_budget", + }) + assert result.allowed is False + assert result.reason == "over_budget" + + def test_warning_threshold_hit_when_below_10_percent(self): + result = BudgetResult.from_dict({ + "allowed": True, + "remaining_tokens": 500, + "limit": 10000, + }) + assert result.warning_threshold_hit is True + + def test_warning_threshold_not_hit_at_exactly_10_percent(self): + result = BudgetResult.from_dict({ + "allowed": True, + "remaining_tokens": 1000, + "limit": 10000, + }) + # 1000/10000 == 0.10, not < 0.10 + assert result.warning_threshold_hit is False + + def test_from_dict_camelcase_keys(self): + result = BudgetResult.from_dict({ + "allowed": True, + "remainingBudget": 8000, + "limit": 10000, + }) + assert result.remaining_tokens == 8000 + + +class TestBudgetCheckMethod: + """Unit tests for GovernsAIClient.budget_check().""" + + @pytest.mark.asyncio + async def test_budget_check_allowed(self): + client = _make_client() + mock_response = MagicMock() + mock_response.data = { + "allowed": True, + "remaining_tokens": 9000, + "limit": 10000, + } + client.budget.http_client.get = AsyncMock(return_value=mock_response) + + result = await client.budget_check( + org_id="org-123", user_id="user-456", estimated_tokens=500 + ) + + assert isinstance(result, BudgetResult) + assert result.allowed is True + assert result.warning_threshold_hit is False + client.budget.http_client.get.assert_called_once_with( + "/api/v1/budget/context", + params={"orgId": "org-123", "userId": "user-456", "estimatedTokens": 500}, + ) + + @pytest.mark.asyncio + async def test_budget_check_denied_when_over_budget(self): + client = _make_client() + mock_response = MagicMock() + mock_response.data = { + "allowed": False, + "remaining_tokens": 0, + "limit": 10000, + "reason": "over_budget", + } + client.budget.http_client.get = AsyncMock(return_value=mock_response) + + result = await client.budget_check( + org_id="org-123", user_id="user-456", estimated_tokens=1000 + ) + assert result.allowed is False + assert result.reason == "over_budget" + + @pytest.mark.asyncio + async def test_budget_check_warning_threshold_hit(self): + """warning_threshold_hit=True when remaining < 10% of limit.""" + client = _make_client() + mock_response = MagicMock() + mock_response.data = { + "allowed": True, + "remaining_tokens": 500, + "limit": 10000, + } + client.budget.http_client.get = AsyncMock(return_value=mock_response) + + result = await client.budget_check( + org_id="org-123", user_id="user-456", estimated_tokens=100 + ) + assert result.allowed is True + assert result.warning_threshold_hit is True + + @pytest.mark.asyncio + async def test_async_budget_check_is_alias(self): + client = _make_client() + mock_response = MagicMock() + mock_response.data = {"allowed": True, "remaining_tokens": 5000, "limit": 10000} + client.budget.http_client.get = AsyncMock(return_value=mock_response) + + result = await client.async_budget_check( + org_id="org-123", user_id="user-456", estimated_tokens=200 + ) + assert isinstance(result, BudgetResult) + + +class TestRecordUsageMethod: + """Unit tests for GovernsAIClient.record_usage().""" + + @pytest.mark.asyncio + async def test_record_usage_correct_payload(self): + client = _make_client() + client.budget.record_usage = AsyncMock() + + record = UsageRecord( + user_id="user-123", + org_id="org-456", + provider="openai", + model="gpt-4o-mini", + input_tokens=100, + output_tokens=80, + cost=0.0017, + cost_type="external", + ) + await client.record_usage(record) + client.budget.record_usage.assert_called_once_with(record) + + @pytest.mark.asyncio + async def test_record_usage_dict_converted_to_usage_record(self): + client = _make_client() + client.budget.record_usage = AsyncMock() + + await client.record_usage({ + "userId": "user-123", + "orgId": "org-456", + "provider": "openai", + "model": "gpt-4", + "inputTokens": 50, + "outputTokens": 30, + "cost": 0.001, + "costType": "external", + }) + call_arg = client.budget.record_usage.call_args[0][0] + assert isinstance(call_arg, UsageRecord) + assert call_arg.user_id == "user-123" + assert call_arg.model == "gpt-4" + + @pytest.mark.asyncio + async def test_async_record_usage_is_alias(self): + client = _make_client() + client.budget.record_usage = AsyncMock() + + await client.async_record_usage( + UsageRecord( + user_id="u", org_id="o", provider="openai", model="gpt-4", + input_tokens=10, output_tokens=5, cost=0.0001, cost_type="external", + ) + ) + assert client.budget.record_usage.call_count == 1