diff --git a/CHANGELOG.md b/CHANGELOG.md index f8c83fd..2128a4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.1.1] - 2026-03-26 + +### Added + +- **OpenAI Agents SDK Adapter** — `ContractRunHooks(RunHooks)` for effect gating via `on_tool_start`, token tracking via `on_llm_end`, postcondition evaluation via `on_agent_end`. Pinned to `openai-agents==0.8.4` +- **Claude Agent SDK Adapter** — `ContractHooks` with structured deny via PreToolUse (not exception). Cost/token extraction from ResultMessage. Pinned to `claude-agent-sdk==0.1.50` (Python 3.10+) +- **Precondition Evaluation** — `contract.preconditions[]` evaluated on input BEFORE agent runs. Reuses CEL-like expression evaluator. `PreconditionError` blocks execution before tokens are spent. Wired into `ContractEnforcer.check_preconditions()` and `@enforce_contract` decorator +- **GitHub Action** — `pyyush/agentcontracts@v0.1.1` composite action for CI contract validation +- **README Badge** — PyPI version and CI status badges +- 35 new tests (188 total) + ## [0.1.0] - 2026-03-25 First release. YAML spec + Python SDK for production agent reliability. @@ -32,4 +43,5 @@ First release. YAML spec + Python SDK for production agent reliability. - **Specification** — Human-readable spec narrative (`SPECIFICATION.md`) - **Examples** — Reference contracts for all 3 tiers +[0.1.1]: https://github.com/pyyush/agentcontracts/releases/tag/v0.1.1 [0.1.0]: https://github.com/pyyush/agentcontracts/releases/tag/v0.1.0 diff --git a/CLAUDE.md b/CLAUDE.md index 20c9714..5874091 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,6 +28,8 @@ src/agent_contracts/ langchain.py LangChain CallbackHandler crewai.py CrewAI ContractGuard pydantic_ai.py Pydantic AI ContractMiddleware + openai_agents.py OpenAI Agents SDK RunHooks + claude_agent.py Claude Agent SDK ContractHooks examples/ Reference contracts (Tier 0, 1, 2) tests/ pytest test suite ``` diff --git a/README.md b/README.md index d773b67..ec2c73c 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # Agent Contracts +[![PyPI](https://img.shields.io/pypi/v/aicontracts)](https://pypi.org/project/aicontracts/) +[![CI](https://github.com/pyyush/agentcontracts/actions/workflows/ci.yml/badge.svg)](https://github.com/pyyush/agentcontracts/actions/workflows/ci.yml) + **YAML spec + validation SDK for production agent reliability.** Cost control, tool-use security, and audit trails in under 30 minutes of integration. Works with any framework. Enforces at the runtime layer, not via prompts. @@ -98,6 +101,20 @@ middleware = ContractMiddleware.from_file("AGENT_CONTRACT.yaml") result = await middleware.run(agent, prompt) ``` +**OpenAI Agents SDK:** +```python +from agent_contracts.adapters.openai_agents import ContractRunHooks +hooks = ContractRunHooks.from_file("AGENT_CONTRACT.yaml") +result = await Runner.run(agent, "prompt", run_hooks=[hooks]) +``` + +**Claude Agent SDK:** +```python +from agent_contracts.adapters.claude_agent import ContractHooks +hooks = ContractHooks.from_file("AGENT_CONTRACT.yaml") +# Pass hooks.pre_tool_use to ClaudeAgentOptions +``` + ## Three Tiers Start simple, add guarantees as production demands. diff --git a/action.yml b/action.yml new file mode 100644 index 0000000..b4b1c1c --- /dev/null +++ b/action.yml @@ -0,0 +1,58 @@ +name: 'AI Contracts Validate' +description: 'Validate agent contracts against the AI Contracts spec' +branding: + icon: 'shield' + color: 'blue' + +inputs: + contract: + description: 'Path to contract YAML file(s), space-separated' + required: true + fail-on-warning: + description: 'Fail if contract has upgrade recommendations' + required: false + default: 'false' + python-version: + description: 'Python version to use' + required: false + default: '3.11' + +outputs: + outcome: + description: 'pass or fail' + value: ${{ steps.validate.outputs.outcome }} + tier: + description: 'Contract tier (0, 1, or 2)' + value: ${{ steps.validate.outputs.tier }} + +runs: + using: 'composite' + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Install aicontracts + shell: bash + run: pip install aicontracts + + - name: Validate contracts + id: validate + shell: bash + run: | + outcome="pass" + for contract in ${{ inputs.contract }}; do + echo "::group::Validating $contract" + result=$(aicontracts validate "$contract" -j) + if [ $? -eq 0 ]; then + tier=$(echo "$result" | python3 -c "import sys,json; print(json.load(sys.stdin)['tier'])") + echo "tier=$tier" >> "$GITHUB_OUTPUT" + else + outcome="fail" + fi + echo "::endgroup::" + done + echo "outcome=$outcome" >> "$GITHUB_OUTPUT" + if [ "$outcome" = "fail" ]; then + exit 1 + fi diff --git a/pyproject.toml b/pyproject.toml index 9ad175d..d54a0ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,11 +40,15 @@ otel = ["opentelemetry-api>=1.20"] langchain = ["langchain-core>=0.2"] crewai = ["crewai>=0.50"] pydantic-ai = ["pydantic-ai>=0.1"] +openai = ["openai-agents==0.8.4"] +claude = ["claude-agent-sdk==0.1.50; python_version>='3.10'"] all = [ "aicontracts[otel]", "aicontracts[langchain]", "aicontracts[crewai]", "aicontracts[pydantic-ai]", + "aicontracts[openai]", + "aicontracts[claude]", ] dev = [ "pytest>=8.0", diff --git a/src/agent_contracts/__init__.py b/src/agent_contracts/__init__.py index 98a0533..7a9bd09 100644 --- a/src/agent_contracts/__init__.py +++ b/src/agent_contracts/__init__.py @@ -16,7 +16,7 @@ from agent_contracts.effects import EffectDeniedError, EffectGuard from agent_contracts.enforcer import ContractEnforcer, ContractViolation, enforce_contract from agent_contracts.loader import ContractLoadError, load_contract, validate_contract -from agent_contracts.postconditions import PostconditionError +from agent_contracts.postconditions import PostconditionError, PreconditionError from agent_contracts.tier import TierRecommendation, assess_tier, recommend_upgrades from agent_contracts.types import ( Contract, @@ -27,6 +27,7 @@ FailureModel, ObservabilityConfig, PostconditionDef, + PreconditionDef, ResourceBudgets, SLOConfig, VersioningConfig, @@ -39,6 +40,7 @@ "Contract", "ContractIdentity", "PostconditionDef", + "PreconditionDef", "EffectsAuthorized", "EffectsDeclared", "ResourceBudgets", @@ -67,6 +69,7 @@ "BudgetExceededError", # Postconditions "PostconditionError", + "PreconditionError", # Violations "ViolationEvent", "ViolationEmitter", diff --git a/src/agent_contracts/_version.py b/src/agent_contracts/_version.py index d9819ba..3cd27c7 100644 --- a/src/agent_contracts/_version.py +++ b/src/agent_contracts/_version.py @@ -1,3 +1,3 @@ """Agent Contracts version.""" -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/src/agent_contracts/adapters/__init__.py b/src/agent_contracts/adapters/__init__.py index f2850dd..cd5e596 100644 --- a/src/agent_contracts/adapters/__init__.py +++ b/src/agent_contracts/adapters/__init__.py @@ -7,4 +7,6 @@ pip install aicontracts[langchain] pip install aicontracts[crewai] pip install aicontracts[pydantic-ai] + pip install aicontracts[openai] + pip install aicontracts[claude] # Python 3.10+ """ diff --git a/src/agent_contracts/adapters/claude_agent.py b/src/agent_contracts/adapters/claude_agent.py new file mode 100644 index 0000000..688b419 --- /dev/null +++ b/src/agent_contracts/adapters/claude_agent.py @@ -0,0 +1,135 @@ +"""Claude Agent SDK adapter — contract enforcement via hooks. + +Usage (3 lines): + from agent_contracts.adapters.claude_agent import ContractHooks + hooks = ContractHooks.from_file("contract.yaml") + # Pass hooks.pre_tool_use and hooks.post_tool_use to ClaudeAgentOptions + +Requires: pip install aicontracts[claude] (Python 3.10+) + +Design: PreToolUse returns structured deny (not exception) when a tool +is unauthorized. This layers ON TOP of the SDK's own allowed_tools +mechanism — it does not replace it. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +from agent_contracts.enforcer import ContractEnforcer, ContractViolation +from agent_contracts.loader import load_contract +from agent_contracts.types import Contract +from agent_contracts.violations import ViolationEvent + + +class ContractHooks: + """Claude Agent SDK hooks that enforce an agent contract. + + Generates async callables for PreToolUse and PostToolUse that can be + passed to ClaudeAgentOptions hooks configuration. + + PreToolUse returns structured deny for unauthorized tools. + PostToolUse tracks tool calls against budget. + """ + + def __init__( + self, + contract: Contract, + *, + violation_destination: str = "stdout", + violation_callback: Optional[Callable[[ViolationEvent], None]] = None, + ) -> None: + self._enforcer = ContractEnforcer( + contract, + violation_destination=violation_destination, + violation_callback=violation_callback, + ) + + @classmethod + def from_file( + cls, + path: Union[str, Path], + *, + violation_destination: str = "stdout", + ) -> "ContractHooks": + """Create hooks from a contract YAML file.""" + contract = load_contract(path) + return cls(contract, violation_destination=violation_destination) + + @property + def enforcer(self) -> ContractEnforcer: + return self._enforcer + + @property + def violations(self) -> List[ViolationEvent]: + return self._enforcer.violations + + async def pre_tool_use( + self, + input_data: Dict[str, Any], + tool_use_id: Optional[str] = None, + context: Any = None, + ) -> Dict[str, Any]: + """PreToolUse hook — check authorization before tool executes. + + Returns structured deny if tool is not authorized. + Returns empty dict to allow execution. + """ + tool_name = input_data.get("tool_name", "") + + try: + self._enforcer.check_tool_call(tool_name) + except ContractViolation: + return { + "hookSpecificOutput": { + "hookEventName": input_data.get("hook_event_name", "PreToolUse"), + "permissionDecision": "deny", + "permissionDecisionReason": ( + f"Tool '{tool_name}' not authorized by agent contract " + f"'{self._enforcer.contract.identity.name}'" + ), + } + } + + return {} + + async def post_tool_use( + self, + input_data: Dict[str, Any], + tool_use_id: Optional[str] = None, + context: Any = None, + ) -> Dict[str, Any]: + """PostToolUse hook — observe tool completion.""" + return {} + + def get_hooks_config(self) -> Dict[str, Any]: + """Return a hooks dict suitable for ClaudeAgentOptions. + + Usage: + options = ClaudeAgentOptions(hooks=contract_hooks.get_hooks_config()) + """ + return { + "PreToolUse": [{"hooks": [self.pre_tool_use]}], + "PostToolUse": [{"hooks": [self.post_tool_use]}], + } + + def track_result(self, result_message: Any) -> None: + """Extract cost and token usage from a ResultMessage. + + Call this after the agent run completes: + async for message in query(prompt="..."): + if hasattr(message, 'total_cost_usd'): + hooks.track_result(message) + """ + cost = getattr(result_message, "total_cost_usd", None) + if cost is not None and cost > 0: + self._enforcer.add_cost(cost) + + usage = getattr(result_message, "usage", None) + if isinstance(usage, dict): + input_tokens = usage.get("input_tokens", 0) + output_tokens = usage.get("output_tokens", 0) + total = input_tokens + output_tokens + if total > 0: + self._enforcer.add_tokens(total) diff --git a/src/agent_contracts/adapters/openai_agents.py b/src/agent_contracts/adapters/openai_agents.py new file mode 100644 index 0000000..fad45d0 --- /dev/null +++ b/src/agent_contracts/adapters/openai_agents.py @@ -0,0 +1,139 @@ +"""OpenAI Agents SDK adapter — contract enforcement via RunHooks. + +Usage (3 lines): + from agent_contracts.adapters.openai_agents import ContractRunHooks + hooks = ContractRunHooks.from_file("contract.yaml") + result = await Runner.run(agent, "prompt", run_hooks=[hooks]) + +Requires: pip install aicontracts[openai] + +Honest limitation: on_tool_start fires AFTER the LLM has already decided +to call the tool and spent reasoning tokens. Blocking here prevents +execution but not the token cost of the decision. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, List, Optional, Union + +from agent_contracts.enforcer import ContractEnforcer, ContractViolation +from agent_contracts.loader import load_contract +from agent_contracts.types import Contract +from agent_contracts.violations import ViolationEvent + +try: + from openai_agents import RunHooks +except ImportError: + # Stub so the module can be imported without openai-agents + class RunHooks: # type: ignore[no-redef] + """Stub for when openai-agents is not installed.""" + pass + + +class ContractRunHooks(RunHooks): # type: ignore[misc] + """OpenAI Agents SDK RunHooks that enforces an agent contract. + + Intercepts tool calls for effect gating and budget tracking. + Tracks token usage from LLM responses. + Evaluates postconditions on agent completion. + """ + + def __init__( + self, + contract: Contract, + *, + violation_destination: str = "stdout", + violation_callback: Optional[Any] = None, + raise_on_violation: bool = True, + ) -> None: + self._enforcer = ContractEnforcer( + contract, + violation_destination=violation_destination, + violation_callback=violation_callback, + ) + self._raise_on_violation = raise_on_violation + + @classmethod + def from_file( + cls, + path: Union[str, Path], + *, + violation_destination: str = "stdout", + raise_on_violation: bool = True, + ) -> "ContractRunHooks": + """Create hooks from a contract YAML file.""" + contract = load_contract(path) + return cls( + contract, + violation_destination=violation_destination, + raise_on_violation=raise_on_violation, + ) + + @property + def enforcer(self) -> ContractEnforcer: + return self._enforcer + + @property + def violations(self) -> List[ViolationEvent]: + return self._enforcer.violations + + async def on_tool_start( + self, context: Any, agent: Any, tool: Any + ) -> None: + """Called before a tool executes — enforce effect authorization and budget. + + Raises ContractViolation to prevent unauthorized tool execution. + Note: LLM reasoning tokens for this tool call are already spent. + """ + tool_name = getattr(tool, "name", str(tool)) + try: + self._enforcer.check_tool_call(tool_name) + except ContractViolation: + if self._raise_on_violation: + raise + + async def on_tool_end( + self, context: Any, agent: Any, tool: Any, result: str + ) -> None: + """Called after a tool executes.""" + pass + + async def on_llm_start( + self, context: Any, agent: Any, system_prompt: Any, input_items: Any + ) -> None: + """Called before an LLM invocation.""" + pass + + async def on_llm_end( + self, context: Any, agent: Any, response: Any + ) -> None: + """Called after an LLM invocation — track token usage.""" + # Extract usage from the response object + usage = getattr(response, "usage", None) + if usage is not None: + total = getattr(usage, "total_tokens", 0) + if total and total > 0: + self._enforcer.add_tokens(total) + + async def on_agent_start( + self, context: Any, agent: Any + ) -> None: + """Called when an agent starts executing.""" + pass + + async def on_agent_end( + self, context: Any, agent: Any, output: Any + ) -> None: + """Called when an agent finishes — evaluate postconditions.""" + try: + self._enforcer.evaluate_postconditions(output) + except ContractViolation: + if self._raise_on_violation: + raise + + async def on_handoff( + self, context: Any, input: Any + ) -> None: + """Called during agent handoffs — observe, don't enforce yet.""" + pass diff --git a/src/agent_contracts/enforcer.py b/src/agent_contracts/enforcer.py index fd9dba2..941dbc3 100644 --- a/src/agent_contracts/enforcer.py +++ b/src/agent_contracts/enforcer.py @@ -22,7 +22,10 @@ from agent_contracts.loader import load_contract from agent_contracts.postconditions import ( PostconditionResult, + PreconditionError, + PreconditionResult, evaluate_postconditions, + evaluate_preconditions, ) from agent_contracts.types import Contract from agent_contracts.violations import ViolationEmitter, ViolationEvent @@ -77,6 +80,31 @@ def violations(self) -> List[ViolationEvent]: def warnings(self) -> List[str]: return list(self._warnings) + # --- Precondition evaluation --- + + def check_preconditions(self, input_data: Any) -> List[PreconditionResult]: + """Evaluate preconditions against input data before agent runs. + + Raises ContractViolation if any precondition fails. + Returns empty list if no preconditions are defined. + """ + if not self._contract.preconditions: + return [] + try: + return evaluate_preconditions( + self._contract.preconditions, input_data, raise_on_failure=True + ) + except PreconditionError as e: + event = self._emitter.create_event( + contract_id=self._contract.identity.name, + contract_version=self._contract.identity.version, + violated_clause=f"inputs.preconditions.{e.precondition.name}", + evidence={"check": e.precondition.check}, + severity="critical", + enforcement="blocked", + ) + raise ContractViolation(str(e), event=event) from e + # --- Input validation --- def validate_input(self, input_data: Any) -> List[str]: @@ -268,6 +296,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: f"Input validation failed: {errors}" ) + # Pre: evaluate preconditions + if args and contract.preconditions: + enforcer.check_preconditions(args[0]) + result = fn(*args, **kwargs) # Post: validate output diff --git a/src/agent_contracts/postconditions.py b/src/agent_contracts/postconditions.py index 63625c7..48d4a17 100644 --- a/src/agent_contracts/postconditions.py +++ b/src/agent_contracts/postconditions.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional -from agent_contracts.types import PostconditionDef +from agent_contracts.types import PostconditionDef, PreconditionDef # Safe operators for expression evaluation _OPERATORS = { @@ -194,6 +194,50 @@ def evaluate_expression(check: str, context: Dict[str, Any]) -> bool: return bool(val) +class PreconditionError(Exception): + """Raised when a precondition fails (input rejected before agent runs).""" + + def __init__(self, precondition: PreconditionDef, input_data: Any) -> None: + self.precondition = precondition + self.input_data = input_data + super().__init__( + f"Precondition '{precondition.name}' failed: {precondition.check}" + ) + + +@dataclass +class PreconditionResult: + """Result of evaluating a precondition.""" + + precondition: PreconditionDef + passed: bool + + +def evaluate_preconditions( + preconditions: List[PreconditionDef], + input_data: Any, + *, + raise_on_failure: bool = True, +) -> List[PreconditionResult]: + """Evaluate all preconditions against input data. + + Preconditions use the same expression evaluator as postconditions. + Context key is 'input' instead of 'output'. + + If raise_on_failure is True, raises PreconditionError on first failure. + """ + context: Dict[str, Any] = {"input": input_data} + + results: List[PreconditionResult] = [] + for pc in preconditions: + passed = evaluate_expression(pc.check, context) + results.append(PreconditionResult(precondition=pc, passed=passed)) + if not passed and raise_on_failure: + raise PreconditionError(pc, input_data) + + return results + + @dataclass class PostconditionResult: """Result of evaluating a postcondition.""" diff --git a/tests/test_adapters/test_claude_agent.py b/tests/test_adapters/test_claude_agent.py new file mode 100644 index 0000000..43e5b71 --- /dev/null +++ b/tests/test_adapters/test_claude_agent.py @@ -0,0 +1,115 @@ +"""Tests for Claude Agent SDK adapter.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any, Dict +from unittest.mock import MagicMock + +import pytest +import yaml + +from agent_contracts.adapters.claude_agent import ContractHooks + + +@pytest.fixture +def hooks(tmp_path: Path, tier1_data: Dict[str, Any]) -> ContractHooks: + p = tmp_path / "contract.yaml" + p.write_text(yaml.dump(tier1_data, sort_keys=False), encoding="utf-8") + return ContractHooks.from_file(p, violation_destination="callback") + + +def run_async(coro): + """Helper to run async tests.""" + return asyncio.run(coro) + + +class TestContractHooks: + def test_from_file(self, hooks) -> None: + assert hooks.enforcer is not None + assert hooks.enforcer.contract.tier == 1 + + def test_authorized_tool_allows(self, hooks) -> None: + result = run_async(hooks.pre_tool_use({ + "tool_name": "search", + "tool_input": {}, + "hook_event_name": "PreToolUse", + })) + assert result == {} + + def test_unauthorized_tool_denies(self, hooks) -> None: + result = run_async(hooks.pre_tool_use({ + "tool_name": "delete_all", + "tool_input": {}, + "hook_event_name": "PreToolUse", + })) + assert "hookSpecificOutput" in result + output = result["hookSpecificOutput"] + assert output["permissionDecision"] == "deny" + assert "delete_all" in output["permissionDecisionReason"] + assert output["hookEventName"] == "PreToolUse" + + def test_deny_includes_contract_name(self, hooks) -> None: + result = run_async(hooks.pre_tool_use({ + "tool_name": "bad", + "tool_input": {}, + "hook_event_name": "PreToolUse", + })) + reason = result["hookSpecificOutput"]["permissionDecisionReason"] + assert "test-agent" in reason + + def test_tool_budget_tracking(self, tmp_path, tier1_data) -> None: + tier1_data["resources"]["budgets"]["max_tool_calls"] = 2 + p = tmp_path / "contract.yaml" + p.write_text(yaml.dump(tier1_data, sort_keys=False), encoding="utf-8") + h = ContractHooks.from_file(p) + + # First two allowed + run_async(h.pre_tool_use({"tool_name": "search", "tool_input": {}})) + run_async(h.pre_tool_use({"tool_name": "database.read", "tool_input": {}})) + + # Third denied (budget exceeded — returns deny, not exception) + result = run_async(h.pre_tool_use({"tool_name": "search", "tool_input": {}})) + assert "hookSpecificOutput" in result + assert result["hookSpecificOutput"]["permissionDecision"] == "deny" + + def test_post_tool_use(self, hooks) -> None: + result = run_async(hooks.post_tool_use({ + "tool_name": "search", + "tool_input": {}, + })) + assert result == {} + + def test_get_hooks_config(self, hooks) -> None: + config = hooks.get_hooks_config() + assert "PreToolUse" in config + assert "PostToolUse" in config + assert len(config["PreToolUse"]) == 1 + assert hooks.pre_tool_use in config["PreToolUse"][0]["hooks"] + + def test_track_result_cost(self, hooks) -> None: + msg = MagicMock() + msg.total_cost_usd = 0.05 + msg.usage = {"input_tokens": 100, "output_tokens": 50} + hooks.track_result(msg) + snapshot = hooks.enforcer.budget_tracker.snapshot() + assert snapshot.cost_usd == 0.05 + assert snapshot.tokens == 150 + + def test_track_result_no_cost(self, hooks) -> None: + msg = MagicMock(spec=[]) # No attributes + hooks.track_result(msg) # Should not raise + + def test_track_result_zero_cost(self, hooks) -> None: + msg = MagicMock() + msg.total_cost_usd = 0 + msg.usage = {} + hooks.track_result(msg) + + def test_violations_accumulated(self, hooks) -> None: + run_async(hooks.pre_tool_use({ + "tool_name": "unauthorized", + "tool_input": {}, + })) + assert len(hooks.violations) == 1 diff --git a/tests/test_adapters/test_openai_agents.py b/tests/test_adapters/test_openai_agents.py new file mode 100644 index 0000000..3bc88f4 --- /dev/null +++ b/tests/test_adapters/test_openai_agents.py @@ -0,0 +1,103 @@ +"""Tests for OpenAI Agents SDK adapter.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any, Dict +from unittest.mock import MagicMock + +import pytest +import yaml + +from agent_contracts.adapters.openai_agents import ContractRunHooks +from agent_contracts.enforcer import ContractViolation + + +@pytest.fixture +def hooks(tmp_path: Path, tier1_data: Dict[str, Any]) -> ContractRunHooks: + p = tmp_path / "contract.yaml" + p.write_text(yaml.dump(tier1_data, sort_keys=False), encoding="utf-8") + return ContractRunHooks.from_file( + p, violation_destination="callback", raise_on_violation=True + ) + + +def run_async(coro): + """Helper to run async tests.""" + return asyncio.run(coro) + + +class TestContractRunHooks: + def test_from_file(self, hooks) -> None: + assert hooks.enforcer is not None + assert hooks.enforcer.contract.tier == 1 + + def test_authorized_tool_passes(self, hooks) -> None: + tool = MagicMock() + tool.name = "search" + run_async(hooks.on_tool_start(None, None, tool)) + + def test_unauthorized_tool_raises(self, hooks) -> None: + tool = MagicMock() + tool.name = "delete_all" + with pytest.raises(ContractViolation, match="not authorized"): + run_async(hooks.on_tool_start(None, None, tool)) + + def test_tool_budget_tracking(self, tmp_path, tier1_data) -> None: + tier1_data["resources"]["budgets"]["max_tool_calls"] = 2 + p = tmp_path / "contract.yaml" + p.write_text(yaml.dump(tier1_data, sort_keys=False), encoding="utf-8") + h = ContractRunHooks.from_file(p, raise_on_violation=True) + + tool = MagicMock() + tool.name = "search" + run_async(h.on_tool_start(None, None, tool)) + tool.name = "database.read" + run_async(h.on_tool_start(None, None, tool)) + + tool.name = "search" + with pytest.raises(ContractViolation): + run_async(h.on_tool_start(None, None, tool)) + + def test_token_tracking_from_llm_end(self, hooks) -> None: + response = MagicMock() + response.usage.total_tokens = 500 + run_async(hooks.on_llm_end(None, None, response)) + assert hooks.enforcer.budget_tracker.snapshot().tokens == 500 + + def test_token_tracking_no_usage(self, hooks) -> None: + response = MagicMock(spec=[]) # No usage attr + run_async(hooks.on_llm_end(None, None, response)) + + def test_postconditions_on_agent_end(self, hooks) -> None: + run_async(hooks.on_agent_end(None, None, {"result": "data"})) + + def test_violations_accumulated(self, hooks) -> None: + tool = MagicMock() + tool.name = "bad_tool" + try: + run_async(hooks.on_tool_start(None, None, tool)) + except ContractViolation: + pass + assert len(hooks.violations) == 1 + + def test_non_raising_mode(self, tmp_path, tier1_data) -> None: + p = tmp_path / "contract.yaml" + p.write_text(yaml.dump(tier1_data, sort_keys=False), encoding="utf-8") + h = ContractRunHooks.from_file(p, raise_on_violation=False) + tool = MagicMock() + tool.name = "unauthorized" + run_async(h.on_tool_start(None, None, tool)) # Should not raise + + def test_on_tool_end(self, hooks) -> None: + run_async(hooks.on_tool_end(None, None, None, "result")) + + def test_on_agent_start(self, hooks) -> None: + run_async(hooks.on_agent_start(None, None)) + + def test_on_handoff(self, hooks) -> None: + run_async(hooks.on_handoff(None, None)) + + def test_on_llm_start(self, hooks) -> None: + run_async(hooks.on_llm_start(None, None, None, None)) diff --git a/tests/test_preconditions.py b/tests/test_preconditions.py new file mode 100644 index 0000000..a72aa8b --- /dev/null +++ b/tests/test_preconditions.py @@ -0,0 +1,126 @@ +"""Tests for precondition evaluation.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest +import yaml + +from agent_contracts.enforcer import ContractEnforcer, ContractViolation +from agent_contracts.loader import load_contract +from agent_contracts.postconditions import ( + PreconditionError, + evaluate_preconditions, +) +from agent_contracts.types import PreconditionDef + + +class TestEvaluatePreconditions: + def test_passing_precondition(self) -> None: + pcs = [PreconditionDef(name="has_query", check="input is not None")] + results = evaluate_preconditions(pcs, {"query": "hello"}) + assert len(results) == 1 + assert results[0].passed is True + + def test_failing_precondition_raises(self) -> None: + pcs = [PreconditionDef(name="has_query", check="input is None")] + with pytest.raises(PreconditionError, match="has_query"): + evaluate_preconditions(pcs, {"query": "hello"}) + + def test_failing_precondition_no_raise(self) -> None: + pcs = [PreconditionDef(name="has_query", check="input is None")] + results = evaluate_preconditions(pcs, {"query": "hello"}, raise_on_failure=False) + assert len(results) == 1 + assert results[0].passed is False + + def test_multiple_preconditions(self) -> None: + pcs = [ + PreconditionDef(name="not_none", check="input is not None"), + PreconditionDef(name="has_key", check='input.type == "search"'), + ] + results = evaluate_preconditions( + pcs, {"type": "search"}, raise_on_failure=False + ) + assert all(r.passed for r in results) + + def test_first_failure_stops_on_raise(self) -> None: + pcs = [ + PreconditionDef(name="fails", check="input is None"), + PreconditionDef(name="never_reached", check="input is not None"), + ] + with pytest.raises(PreconditionError, match="fails"): + evaluate_preconditions(pcs, "data") + + def test_context_key_is_input(self) -> None: + pcs = [PreconditionDef(name="check_field", check='input.role == "admin"')] + results = evaluate_preconditions( + pcs, {"role": "admin"}, raise_on_failure=False + ) + assert results[0].passed is True + + def test_precondition_error_attributes(self) -> None: + pc = PreconditionDef(name="test_pc", check="input is None") + err = PreconditionError(pc, {"data": 1}) + assert err.precondition is pc + assert err.input_data == {"data": 1} + + +class TestEnforcerPreconditions: + @pytest.fixture + def contract_with_preconditions(self, tmp_path: Path) -> Any: + data = { + "agent_contract": "0.1.0", + "identity": {"name": "test-agent", "version": "1.0.0"}, + "contract": { + "postconditions": [ + {"name": "has_output", "check": "output is not None"} + ] + }, + "inputs": { + "preconditions": [ + {"name": "has_query", "check": "input is not None"}, + {"name": "valid_type", "check": 'input.type == "search"'}, + ] + }, + } + p = tmp_path / "contract.yaml" + p.write_text(yaml.dump(data, sort_keys=False), encoding="utf-8") + return load_contract(p) + + def test_check_preconditions_pass(self, contract_with_preconditions) -> None: + enforcer = ContractEnforcer(contract_with_preconditions) + results = enforcer.check_preconditions({"type": "search"}) + assert len(results) == 2 + assert all(r.passed for r in results) + + def test_check_preconditions_fail(self, contract_with_preconditions) -> None: + enforcer = ContractEnforcer(contract_with_preconditions) + with pytest.raises(ContractViolation, match="valid_type"): + enforcer.check_preconditions({"type": "delete"}) + + def test_check_preconditions_emits_violation(self, contract_with_preconditions) -> None: + enforcer = ContractEnforcer( + contract_with_preconditions, violation_destination="callback" + ) + with pytest.raises(ContractViolation): + enforcer.check_preconditions({"type": "delete"}) + assert len(enforcer.violations) == 1 + assert "preconditions" in enforcer.violations[0].violated_clause + + def test_no_preconditions_returns_empty(self, tmp_path) -> None: + data = { + "agent_contract": "0.1.0", + "identity": {"name": "test", "version": "1.0.0"}, + "contract": { + "postconditions": [ + {"name": "p", "check": "output is not None"} + ] + }, + } + p = tmp_path / "contract.yaml" + p.write_text(yaml.dump(data, sort_keys=False), encoding="utf-8") + contract = load_contract(p) + enforcer = ContractEnforcer(contract) + assert enforcer.check_preconditions({"anything": True}) == []