diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..6545b4d --- /dev/null +++ b/conftest.py @@ -0,0 +1,7 @@ +import sys +from pathlib import Path + +# Ensure src/ package takes precedence over the legacy root-level governs_ai/ directory +_src = str(Path(__file__).parent / "src") +if _src not in sys.path: + sys.path.insert(0, _src) diff --git a/governs_ai/exceptions/precheck.py b/governs_ai/exceptions/precheck.py index 37e4843..4b6eb34 100644 --- a/governs_ai/exceptions/precheck.py +++ b/governs_ai/exceptions/precheck.py @@ -4,6 +4,8 @@ Precheck-specific exceptions. """ +from typing import Any, Dict, Optional + from .base import GovernsAIError diff --git a/pyproject.toml b/pyproject.toml index 79083fe..baefc8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,52 +1,42 @@ [build-system] -requires = ["setuptools>=45", "wheel"] +requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "governs-ai-sdk" -version = "1.0.0" -description = "Python SDK for GovernsAI - AI governance and compliance platform" +version = "0.1.0-alpha.1" +description = "GovernsAI Python SDK" readme = "README.md" requires-python = ">=3.8" -license = {text = "MIT"} -authors = [ - {name = "GovernsAI", email = "support@governs.ai"}, -] +license = { text = "Elastic-2.0" } classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", + "License :: Other/Proprietary License", + "Operating System :: OS Independent", ] dependencies = [ - "requests>=2.25.0", - "pydantic>=1.8.0", - "typing-extensions>=3.10.0", - "aiohttp>=3.8.0", - "asyncio-throttle>=1.0.0", + "httpx>=0.24.0", + "pydantic>=2.0.0", ] [project.optional-dependencies] dev = [ - "pytest>=6.0", - "pytest-asyncio>=0.18.0", - "black>=22.0", - "flake8>=4.0", - "mypy>=0.950", + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "pytest-httpx>=0.21.0", + "black>=23.0.0", + "isort>=5.12.0", + "mypy>=1.0.0", ] -[tool.black] -line-length = 88 -target-version = ['py38'] +[tool.setuptools.packages.find] +where = ["src"] -[tool.mypy] -python_version = "3.8" -warn_return_any = true -warn_unused_configs = true -disallow_untyped_defs = true +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +pythonpath = ["src"] +addopts = "--import-mode=importlib" +markers = [ + "integration: requires running local services (deselect with '-m not integration')", +] diff --git a/src/governs_ai/__init__.py b/src/governs_ai/__init__.py new file mode 100644 index 0000000..a5299e6 --- /dev/null +++ b/src/governs_ai/__init__.py @@ -0,0 +1,10 @@ +from .client import GovernsAIClient, GovernsAIError, PrecheckError +from .types import BudgetResult, PrecheckResult + +__all__ = [ + "GovernsAIClient", + "GovernsAIError", + "PrecheckError", + "PrecheckResult", + "BudgetResult", +] diff --git a/src/governs_ai/client.py b/src/governs_ai/client.py new file mode 100644 index 0000000..1831e28 --- /dev/null +++ b/src/governs_ai/client.py @@ -0,0 +1,337 @@ +import asyncio +import time +from typing import Any, Dict, List, Optional, Union + +import httpx + +from .types import BudgetResult, PrecheckResult + + +class GovernsAIError(Exception): + """Base error for GovernsAI SDK""" + + def __init__( + self, + message: str, + status_code: Optional[int] = None, + response: Optional[Any] = None, + retryable: bool = False, + ): + super().__init__(message) + self.status_code = status_code + self.response = response + self.retryable = retryable + + +class PrecheckError(GovernsAIError): + """Error during precheck operation""" + + pass + + +class GovernsAIClient: + """ + Main SDK client for GovernsAI. + """ + + def __init__( + self, + api_key: str, + base_url: str = "https://api.governs.ai", + org_id: Optional[str] = None, + timeout: float = 30.0, + max_retries: int = 3, + ): + self.api_key = api_key + self.base_url = base_url.rstrip("/") + self.org_id = org_id + self.timeout = timeout + self.max_retries = max_retries + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "X-Governs-Key": self.api_key, + "Content-Type": "application/json", + "X-SDK-Language": "python", + } + + def __repr__(self): + return f"" + + def _get_payload( + self, content: str, tool: str, org_id: Optional[str] + ) -> Dict[str, Any]: + return { + "tool": tool, + "raw_text": content, + "org_id": org_id or self.org_id, + "scope": "net.external", + } + + def _parse_response( + self, response: httpx.Response, latency_ms: float + ) -> PrecheckResult: + if response.status_code >= 400: + try: + error_data = response.json() + message = error_data.get("error") or error_data.get("message") + except Exception: + message = None + + if not message: + message = f"HTTP {response.status_code} {response.reason_phrase}" + + retryable = response.status_code >= 500 or response.status_code == 429 + raise PrecheckError( + message, + status_code=response.status_code, + response=response, + retryable=retryable, + ) + + data = response.json() + return PrecheckResult( + decision=data.get("decision", "deny"), + redacted_content=data.get("redacted_content") + or data.get("content", {}).get("raw_text"), + reasons=data.get("reasons", []), + latency_ms=latency_ms, + ) + + def precheck( + self, + content: str, + tool: str, + org_id: Optional[str] = None, + ) -> PrecheckResult: + """ + Check a request for governance compliance. + """ + payload = self._get_payload(content, tool, org_id) + start_time = time.time() + + last_error_msg = "Unknown error" + for attempt in range(self.max_retries + 1): + try: + with httpx.Client(timeout=self.timeout) as client: + response = client.post( + f"{self.base_url}/api/v1/precheck", + json=payload, + headers=self.headers, + ) + + if response.status_code >= 500 or response.status_code == 429: + last_error_msg = ( + f"HTTP {response.status_code} {response.reason_phrase}" + ) + if attempt < self.max_retries: + time.sleep(2**attempt) + continue + else: + break + + latency_ms = (time.time() - start_time) * 1000 + return self._parse_response(response, latency_ms) + except (httpx.RequestError, httpx.TimeoutException) as e: + last_error_msg = str(e) + if attempt < self.max_retries: + time.sleep(2**attempt) + continue + + raise PrecheckError(f"Max retries exceeded: {last_error_msg}") + + async def async_precheck( + self, + content: str, + tool: str, + org_id: Optional[str] = None, + ) -> PrecheckResult: + """ + Async version of precheck. + """ + payload = self._get_payload(content, tool, org_id) + start_time = time.time() + + last_error_msg = "Unknown error" + for attempt in range(self.max_retries + 1): + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.base_url}/api/v1/precheck", + json=payload, + headers=self.headers, + ) + + if response.status_code >= 500 or response.status_code == 429: + last_error_msg = ( + f"HTTP {response.status_code} {response.reason_phrase}" + ) + if attempt < self.max_retries: + await asyncio.sleep(2**attempt) + continue + else: + break + + latency_ms = (time.time() - start_time) * 1000 + return self._parse_response(response, latency_ms) + except (httpx.RequestError, httpx.TimeoutException) as e: + last_error_msg = str(e) + if attempt < self.max_retries: + await asyncio.sleep(2**attempt) + continue + + raise PrecheckError(f"Max retries exceeded: {last_error_msg}") + + # ------------------------------------------------------------------ + # 1.4c — record_usage() + # ------------------------------------------------------------------ + + def record_usage( + self, + org_id: str, + user_id: str, + tokens: int, + model: str, + *, + provider: str = "openai", + ) -> None: + """Record token usage for a model request. + + Example:: + + client.record_usage( + org_id="org-1", user_id="user-123", + tokens=180, model="gpt-4o-mini", + ) + """ + payload: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "inputTokens": tokens, + "outputTokens": 0, + "model": model, + "provider": provider, + } + with httpx.Client(timeout=self.timeout) as http: + resp = http.post( + f"{self.base_url}/api/v1/usage", + json=payload, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"record_usage failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + + async def async_record_usage( + self, + org_id: str, + user_id: str, + tokens: int, + model: str, + *, + provider: str = "openai", + ) -> None: + """Async variant of :meth:`record_usage`.""" + payload: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "inputTokens": tokens, + "outputTokens": 0, + "model": model, + "provider": provider, + } + async with httpx.AsyncClient(timeout=self.timeout) as http: + resp = await http.post( + f"{self.base_url}/api/v1/usage", + json=payload, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"record_usage failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + + # ------------------------------------------------------------------ + # 1.4c — budget_check() + # ------------------------------------------------------------------ + + def budget_check( + self, + org_id: str, + user_id: str, + estimated_tokens: int = 0, + ) -> BudgetResult: + """Check whether the user/org is within budget. + + Example:: + + budget = client.budget_check(org_id="org-1", user_id="u1", estimated_tokens=500) + if not budget.allowed: + raise RuntimeError("Budget exceeded") + """ + params: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "estimatedTokens": estimated_tokens, + } + with httpx.Client(timeout=self.timeout) as http: + resp = http.get( + f"{self.base_url}/api/v1/budget/context", + params=params, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"budget_check failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + data = resp.json() + limit = data.get("limit", data.get("monthly_limit", 0)) + remaining = data.get("remaining_tokens", data.get("remaining", limit)) + allowed = data.get("allowed", remaining > 0) + warning_threshold_hit = limit > 0 and (remaining / limit) < 0.10 + return BudgetResult( + allowed=allowed, + remaining_tokens=int(remaining), + limit=int(limit), + warning_threshold_hit=warning_threshold_hit, + reason=data.get("reason", ""), + ) + + async def async_budget_check( + self, + org_id: str, + user_id: str, + estimated_tokens: int = 0, + ) -> BudgetResult: + """Async variant of :meth:`budget_check`.""" + params: Dict[str, Any] = { + "orgId": org_id or self.org_id, + "userId": user_id, + "estimatedTokens": estimated_tokens, + } + async with httpx.AsyncClient(timeout=self.timeout) as http: + resp = await http.get( + f"{self.base_url}/api/v1/budget/context", + params=params, + headers=self.headers, + ) + if resp.status_code >= 400: + raise GovernsAIError( + f"budget_check failed with HTTP {resp.status_code}: {resp.text}", + status_code=resp.status_code, + ) + data = resp.json() + limit = data.get("limit", data.get("monthly_limit", 0)) + remaining = data.get("remaining_tokens", data.get("remaining", limit)) + allowed = data.get("allowed", remaining > 0) + warning_threshold_hit = limit > 0 and (remaining / limit) < 0.10 + return BudgetResult( + allowed=allowed, + remaining_tokens=int(remaining), + limit=int(limit), + warning_threshold_hit=warning_threshold_hit, + reason=data.get("reason", ""), + ) diff --git a/src/governs_ai/types.py b/src/governs_ai/types.py new file mode 100644 index 0000000..41113e3 --- /dev/null +++ b/src/governs_ai/types.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class PrecheckResult: + decision: str + redacted_content: Optional[str] = None + reasons: List[str] = field(default_factory=list) + latency_ms: float = 0.0 + + +@dataclass +class BudgetResult: + """Result of a budget_check call. + + Example:: + + budget = client.budget_check(org_id="org-1", user_id="user-1", estimated_tokens=500) + if not budget.allowed: + raise RuntimeError("Budget exceeded") + if budget.warning_threshold_hit: + logger.warning("Less than 10% budget remaining") + """ + + allowed: bool + remaining_tokens: int + limit: int + warning_threshold_hit: bool + reason: str = "" diff --git a/tests/__init__.py b/tests/__init__.py index 920b462..e69de29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +0,0 @@ -""" -Tests for the GovernsAI Python SDK. -""" diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_precheck_integration.py b/tests/integration/test_precheck_integration.py new file mode 100644 index 0000000..a19345d --- /dev/null +++ b/tests/integration/test_precheck_integration.py @@ -0,0 +1,36 @@ +"""Integration tests for precheck() — require local precheck service on localhost:8000.""" + +import os + +import pytest + +from governs_ai.client import GovernsAIClient + + +@pytest.fixture +def client(): + return GovernsAIClient( + api_key=os.getenv("GOVERNS_API_KEY", "dev-key"), + base_url=os.getenv("GOVERNS_BASE_URL", "http://localhost:8000"), + org_id=os.getenv("GOVERNS_ORG_ID", "org-dev"), + ) + + +@pytest.mark.integration +def test_precheck_returns_real_decision(client): + """precheck() against local service returns a valid decision.""" + result = client.precheck(content="Hello from integration test", tool="model.chat") + assert result.decision in ("allow", "deny", "transform", "confirm") + assert isinstance(result.redacted_content, str) + assert isinstance(result.reasons, list) + assert result.latency_ms > 0 + + +@pytest.mark.integration +async def test_async_precheck_returns_real_decision(client): + """async_precheck() against local service returns a valid decision.""" + result = await client.async_precheck( + content="Hello from async integration test", tool="model.chat" + ) + assert result.decision in ("allow", "deny", "transform", "confirm") + assert result.latency_ms > 0 diff --git a/tests/integration_test.py b/tests/integration_test.py new file mode 100644 index 0000000..996633d --- /dev/null +++ b/tests/integration_test.py @@ -0,0 +1,28 @@ +import asyncio +import os + +from governs_ai.client import GovernsAIClient + + +async def main(): + api_key = os.getenv("GOVERNS_API_KEY", "test-key") + base_url = os.getenv("GOVERNS_BASE_URL", "http://localhost:8000") + org_id = os.getenv("GOVERNS_ORG_ID", "test-org") + + client = GovernsAIClient(api_key=api_key, base_url=base_url, org_id=org_id) + + print(f"Checking precheck against {base_url}...") + try: + result = await client.async_precheck( + content="Hello, is this safe?", tool="chat" + ) + print(f"Decision: {result.decision}") + print(f"Redacted: {result.redacted_content}") + print(f"Reasons: {result.reasons}") + print(f"Latency: {result.latency_ms:.2f}ms") + except Exception as e: + print(f"Precheck failed: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_client.py b/tests/test_client.py index 5460049..9147eb7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,158 +1,99 @@ -""" -Tests for the GovernsAI client. -""" +import json +import time +from unittest.mock import patch import pytest -import asyncio -from unittest.mock import AsyncMock, MagicMock -from governs_ai import GovernsAIClient, GovernsAIConfig -from governs_ai.models import PrecheckRequest, PrecheckResponse, Decision - - -class TestGovernsAIClient: - """Test cases for GovernsAIClient.""" - - @pytest.fixture - def mock_http_client(self): - """Mock HTTP client for testing.""" - mock_client = AsyncMock() - mock_response = MagicMock() - mock_response.is_success = True - mock_response.data = {"status": "healthy"} - mock_client.get.return_value = mock_response - return mock_client - - @pytest.fixture - def client(self, mock_http_client): - """Create a test client.""" - config = GovernsAIConfig( - api_key="test-key", - org_id="test-org", - http_client=mock_http_client, - ) - return GovernsAIClient(config=config) - - def test_default_base_url_is_production(self): - """Default base URL should target managed API, not localhost.""" - config = GovernsAIConfig(api_key="test-key", org_id="test-org") - assert config.base_url == "https://api.governsai.com" - - @pytest.mark.asyncio - async def test_test_connection_success(self, client, mock_http_client): - """Test successful connection test.""" - result = await client.test_connection() - assert result is True - mock_http_client.get.assert_called_once_with("/api/v1/health") - - @pytest.mark.asyncio - async def test_test_connection_failure(self, client, mock_http_client): - """Test failed connection test.""" - mock_http_client.get.side_effect = Exception("Connection failed") - result = await client.test_connection() - assert result is False - - @pytest.mark.asyncio - async def test_precheck_request(self, client, mock_http_client): - """Test precheck request.""" - # Mock precheck response - mock_response = MagicMock() - mock_response.is_success = True - mock_response.data = { - "decision": "allow", - "reasons": [], - "requiresConfirmation": False - } - mock_http_client.post.return_value = mock_response - - result = await client.precheck_request( - tool="model.chat", - scope="net.external", - raw_text="Hello", - payload={"messages": []}, - tags=["test"], - user_id="user-123" - ) - - assert isinstance(result, PrecheckResponse) - assert result.decision == Decision.ALLOW - mock_http_client.post.assert_called_once() - - @pytest.mark.asyncio - async def test_get_budget_context(self, client, mock_http_client): - """Test get budget context.""" - mock_response = MagicMock() - mock_response.is_success = True - mock_response.data = { - "monthlyLimit": 1000.0, - "currentSpend": 250.0, - "remainingBudget": 750.0, - "currency": "USD" - } - mock_http_client.get.return_value = mock_response - - result = await client.get_budget_context("user-123") - - assert result.monthly_limit == 1000.0 - assert result.current_spend == 250.0 - assert result.remaining_budget == 750.0 - assert result.currency == "USD" - - @pytest.mark.asyncio - async def test_record_usage(self, client, mock_http_client): - """Test record usage.""" - mock_response = MagicMock() - mock_response.is_success = True - mock_http_client.post.return_value = mock_response - - usage_data = { - "user_id": "user-123", - "org_id": "org-456", - "provider": "openai", - "model": "gpt-4", - "input_tokens": 100, - "output_tokens": 50, - "cost": 0.15, - "cost_type": "external" - } - - await client.record_usage(usage_data) - - mock_http_client.post.assert_called_once() - - @pytest.mark.asyncio - async def test_get_health_status(self, client, mock_http_client): - """Test get health status.""" - mock_response = MagicMock() - mock_response.is_success = True - mock_response.data = { - "status": "healthy", - "services": {"api": "healthy", "db": "healthy"}, - "version": "1.0.0" - } - mock_http_client.get.return_value = mock_response - - result = await client.get_health_status() - - assert result.status == "healthy" - assert "api" in result.services - assert result.version == "1.0.0" - - def test_update_config(self, client): - """Test update configuration.""" - new_config = {"timeout": 60000, "retries": 5} - client.update_config(new_config) - - assert client.config.timeout == 60000 - assert client.config.retries == 5 - - def test_get_config(self, client): - """Test get configuration.""" - config = client.get_config() - assert isinstance(config, GovernsAIConfig) - assert config.api_key == "test-key" - assert config.org_id == "test-org" - - def test_context_and_document_clients_available(self, client): - """Feature parity clients should be initialized on main client.""" - assert client.context is not None - assert client.documents is not None + +from governs_ai.client import GovernsAIClient, PrecheckError +from governs_ai.types import PrecheckResult + + +@pytest.fixture +def client(): + return GovernsAIClient(api_key="test-key", org_id="test-org") + + +def test_client_init(client): + assert client.api_key == "test-key" + assert client.org_id == "test-org" + assert client.base_url == "https://api.governs.ai" + + +def test_precheck_payload(client, httpx_mock): + httpx_mock.add_response( + method="POST", + url="https://api.governs.ai/api/v1/precheck", + json={"decision": "allow", "reasons": []}, + status_code=200, + ) + + result = client.precheck(content="Hello", tool="test-tool") + + assert result.decision == "allow" + + # Verify request payload + request = httpx_mock.get_request() + assert request is not None + data = json.loads(request.read().decode()) + assert data["tool"] == "test-tool" + assert data["raw_text"] == "Hello" + assert data["org_id"] == "test-org" + + +@pytest.mark.asyncio +async def test_async_precheck_payload(client, httpx_mock): + httpx_mock.add_response( + method="POST", + url="https://api.governs.ai/api/v1/precheck", + json={"decision": "allow", "reasons": []}, + status_code=200, + ) + + result = await client.async_precheck(content="Hello", tool="test-tool") + assert result.decision == "allow" + + # Verify request payload + request = httpx_mock.get_request() + assert request is not None + data = json.loads(request.read().decode()) + assert data["tool"] == "test-tool" + + +def test_precheck_retry_on_5xx(client, httpx_mock): + # Add two 500 errors and then a success + httpx_mock.add_response(status_code=500) + httpx_mock.add_response(status_code=500) + httpx_mock.add_response(status_code=200, json={"decision": "allow"}) + + with patch("time.sleep"): + result = client.precheck(content="Hello", tool="test-tool") + + assert result.decision == "allow" + assert len(httpx_mock.get_requests()) == 3 + + +@pytest.mark.asyncio +async def test_async_precheck_retry_on_5xx(client, httpx_mock): + # Add two 500 errors and then a success + httpx_mock.add_response(status_code=500) + httpx_mock.add_response(status_code=500) + httpx_mock.add_response(status_code=200, json={"decision": "allow"}) + + with patch("asyncio.sleep"): + result = await client.async_precheck(content="Hello", tool="test-tool") + + assert result.decision == "allow" + assert len(httpx_mock.get_requests()) == 3 + + +def test_precheck_max_retries_exceeded(client, httpx_mock): + # Add four 500 errors (max_retries is 3, so 4 attempts total) + for _ in range(4): + httpx_mock.add_response(status_code=500) + + with patch("time.sleep"): + with pytest.raises(PrecheckError) as excinfo: + client.precheck(content="Hello", tool="test-tool") + + assert "Max retries exceeded" in str(excinfo.value) + assert len(httpx_mock.get_requests()) == 4 diff --git a/tests/test_record_usage_budget.py b/tests/test_record_usage_budget.py new file mode 100644 index 0000000..c269ceb --- /dev/null +++ b/tests/test_record_usage_budget.py @@ -0,0 +1,82 @@ +"""Unit tests for record_usage() and budget_check().""" + +import json +import pytest +import respx +import httpx + +from governs_ai import GovernsAIClient, BudgetResult + +BASE = "https://api.governs.ai" + + +@pytest.fixture +def client(): + return GovernsAIClient(api_key="test-key", org_id="org-test") + + +@respx.mock +def test_record_usage_sends_correct_payload(client): + route = respx.post(f"{BASE}/api/v1/usage").mock( + return_value=httpx.Response(200, json={"accepted": True}) + ) + client.record_usage(org_id="org-1", user_id="user-123", tokens=100, model="gpt-4o") + body = json.loads(route.calls[0].request.content) + assert body["orgId"] == "org-1" + assert body["userId"] == "user-123" + assert body["inputTokens"] == 100 + assert body["model"] == "gpt-4o" + + +@respx.mock +async def test_async_record_usage_sends_correct_payload(client): + route = respx.post(f"{BASE}/api/v1/usage").mock( + return_value=httpx.Response(200, json={"accepted": True}) + ) + await client.async_record_usage( + org_id="org-1", user_id="user-123", tokens=50, model="gpt-4o-mini" + ) + body = json.loads(route.calls[0].request.content) + assert body["userId"] == "user-123" + assert body["inputTokens"] == 50 + + +@respx.mock +def test_budget_check_allowed(client): + respx.get(f"{BASE}/api/v1/budget/context").mock( + return_value=httpx.Response( + 200, + json={"allowed": True, "remaining_tokens": 9000, "limit": 10000, "reason": ""}, + ) + ) + result = client.budget_check(org_id="org-1", user_id="user-1") + assert isinstance(result, BudgetResult) + assert result.allowed is True + assert result.warning_threshold_hit is False + + +@respx.mock +def test_budget_check_denied_when_over_budget(client): + respx.get(f"{BASE}/api/v1/budget/context").mock( + return_value=httpx.Response( + 200, + json={"allowed": False, "remaining_tokens": 0, "limit": 10000, "reason": "over_budget"}, + ) + ) + result = client.budget_check(org_id="org-1", user_id="user-1") + assert result.allowed is False + assert result.reason == "over_budget" + + +@respx.mock +def test_budget_check_warning_threshold_hit(client): + """warning_threshold_hit=True when remaining < 10% of limit.""" + respx.get(f"{BASE}/api/v1/budget/context").mock( + return_value=httpx.Response( + 200, + json={"allowed": True, "remaining_tokens": 500, "limit": 10000, "reason": ""}, + ) + ) + result = client.budget_check(org_id="org-1", user_id="user-1") + assert result.allowed is True + assert result.warning_threshold_hit is True