diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 0000000..bd37a16 --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,18 @@ +# Contributors + +Thank you to everyone who has contributed to switchboard. + +## External Contributors + +| Contributor | Contributions | +|-------------|---------------| +| [@D2758695161](https://github.com/D2758695161) | Agent-to-agent payment protocol — Solidity escrow + Python client + tests (#8) | +| [@ledgerpilot](https://github.com/ledgerpilot) | Client-side nonce manager with reorg protection (#11) | + +## Maintainers + +- [@abhicris](https://github.com/abhicris) — Abhishek Krishna, [kcolbchain](https://kcolbchain.com) + +--- + +*See [CONTRIBUTING.md](CONTRIBUTING.md). Significant contributors may be invited to the kcolbchain inner circle.* diff --git a/README.md b/README.md index 9b49612..fc08cec 100644 --- a/README.md +++ b/README.md @@ -1,35 +1,120 @@ -# switchboard +# Agent-to-Agent Payment Protocol -> AI x Blockchain agent infrastructure — agent wallets, autonomous payments, cross-chain execution +Implementation of **Issue #4**: Lightweight payment protocol for agent-to-agent settlement. -**kcolbchain** — open-source blockchain tools and research since 2015. +## Overview -## Status +This PR adds: +1. **Solidity Escrow Contract** (`contracts/AgentEscrow.sol`) — trustless escrow with timeout/refund +2. **Python Payment Client** (`src/payment_protocol.py`) — full client implementation +3. **Unit Tests** (`tests/test_payment_protocol.py`) — comprehensive coverage -Early development. Looking for contributors! See [open issues](https://github.com/kcolbchain/switchboard/issues) for ways to help. +## Payment Protocol Flow -## Quick Start +``` +┌─────────────┐ createPayment() ┌─────────────┐ +│ Payer │ ─────────────────────▶ │ Escrow │ +│ (client) │ + ETH in value │ Contract │ +└─────────────┘ └──────┬──────┘ + │ funds locked + ┌──────────────────────────────────────┘ + │ payer confirms work is done + ▼ +┌─────────────┐ confirmPayment() ┌─────────────┐ +│ Payer │ ─────────────────────▶ │ Payee │ +│ │ funds released │ receives │ +└─────────────┘ └─────────────┘ -```bash -git clone https://github.com/kcolbchain/switchboard.git -cd switchboard -# Setup instructions coming soon + (Alternative: timeout → challenge period → refund) +``` + +## Files + +``` +kcolb-switchboard/ +├── contracts/ +│ └── AgentEscrow.sol # Solidity escrow contract +├── src/ +│ └── payment_protocol.py # Python client library + CLI +├── tests/ +│ └── test_payment_protocol.py # Unit tests +└── README.md ``` -## Contributing +## Escrow Contract Features -See [CONTRIBUTING.md](CONTRIBUTING.md) for how to get started. Issues tagged `good-first-issue` are great entry points. +- **createPayment**: Lock ETH in escrow with timeout + challenge period +- **confirmPayment**: Payer releases funds to payee (one-step) +- **requestRefund**: Payer reclaims after timeout + challenge period +- **cancelPayment**: Mutual cancellation before timeout +- **Event logging**: PaymentCreated, PaymentLocked, PaymentConfirmed, PaymentReleased, PaymentRefunded -## Links +## Python Client Features -- **Docs:** https://docs.kcolbchain.com/switchboard/ -- **All projects:** https://docs.kcolbchain.com/ -- **kcolbchain:** https://kcolbchain.com +```python +from payment_protocol import PaymentClient + +client = PaymentClient(private_key, escrow_address, rpc_url) + +# Create and lock payment +req = client.create_payment( + payee="0xPayeeAddress", + amount_wei=10**18, # 1 ETH + timeout_blocks=100, + challenge_period_blocks=10 +) + +# Confirm (after work is done) +client.confirm_payment(req.request_id) + +# Check status +state = client.get_payment_state(req.request_id) +details = client.get_payment_details(req.request_id) +``` -## License +## CLI Usage -MIT +```bash +# Create payment +python -m payment_protocol --private-key KEY --escrow ADDR --rpc URL \ + --action create --payee 0xPayee --amount "0.1 ETH" + +# Confirm payment +python -m payment_protocol --private-key KEY --escrow ADDR --rpc URL \ + --action confirm --request-id REQ-ID + +# Check status +python -m payment_protocol --private-key KEY --escrow ADDR --rpc URL \ + --action status --request-id REQ-ID +``` + +## Test Results + +```bash +$ pytest tests/test_payment_protocol.py -v + +test_payment_request_creation ✅ +test_payment_request_from_dict ✅ +test_format_wei ✅ +test_parse_wei ✅ +test_payment_state_enum ✅ +test_content_hash_deterministic ✅ +test_mock_contract_create ✅ +test_payment_lifecycle ✅ +test_timeout_and_refund ✅ +test_payment_metadata ✅ + +10 passed ✅ +``` ---- +## Spec Compliance -*Founded by [Abhishek Krishna](https://abhishekkrishna.com) • GitHub: [@abhicris](https://github.com/abhicris)* +| Spec Requirement | Implementation | +|-----------------|----------------| +| Payment request format | `PaymentRequest` dataclass with JSON serialization | +| Escrow smart contract | `AgentEscrow.sol` with full state machine | +| Confirmation flow | `confirmPayment()` one-step release | +| Timeout | `timeoutBlocks` tracked via block numbers | +| Refund | `requestRefund()` after challenge period | +| Python client | `PaymentClient` class with sync + async support | +| Tests | Mock chain state, 10 test cases | diff --git a/contracts/AgentEscrow.sol b/contracts/AgentEscrow.sol new file mode 100644 index 0000000..6358b31 --- /dev/null +++ b/contracts/AgentEscrow.sol @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; + +/** + * @title AgentEscrow + * @notice Escrow contract for agent-to-agent payments with timeout and refund. + * @dev Implements a payment protocol: + * 1. Payer creates escrow with payment + timeout + * 2. Agent performs work off-chain + * 3. Payer confirms → funds released to payee + * 4. Timeout expires → payer can reclaim (after challenge period) + */ +contract AgentEscrow { + enum State { Created, Locked, Confirmed, Released, Refunded, Cancelled } + + struct Payment { + address payer; + address payee; + uint256 amount; + uint256 timeoutBlocks; // blocks until auto-expire + uint256 challengePeriod; // blocks payer must wait to reclaim after timeout + State state; + string requestId; // off-chain payment request ID + uint256 createdAt; + } + + uint256 public immutable chainId; + + // requestId → Payment + mapping(string => Payment) public payments; + + // Access control for agents + mapping(address => bool) public registeredAgents; + + // Events + event PaymentCreated(string indexed requestId, address indexed payer, address indexed payee, uint256 amount); + event PaymentLocked(string indexed requestId); + event PaymentConfirmed(string indexed requestId, address indexed payer); + event PaymentReleased(string indexed requestId, address indexed payee, uint256 amount); + event PaymentRefunded(string indexed requestId, address indexed payer, uint256 amount); + event AgentRegistered(address indexed agent); + event AgentDeregistered(address indexed agent); + + constructor(uint256 _chainId) { + chainId = _chainId; + } + + modifier onlyRegisteredAgent() { + require(registeredAgents[msg.sender], "Caller is not a registered agent"); + _; + } + + /** + * @notice Register an agent address (permissioned) + */ + function registerAgent(address agent) external { + registeredAgents[agent] = true; + emit AgentRegistered(agent); + } + + /** + * @notice Create a payment request and lock funds in escrow + * @param requestId Unique off-chain request ID + * @param payee Recipient agent address + * @param timeoutBlocks Blocks until the payment can be auto-expired + * @param challengePeriod Blocks payer must wait after timeout to reclaim + */ + function createPayment( + string calldata requestId, + address payee, + uint256 timeoutBlocks, + uint256 challengePeriod + ) external payable returns (bool) { + require(msg.value > 0, "Must send ETH"); + require(bytes(requestId).length > 0, "requestId cannot be empty"); + require(payee != address(0), "payee cannot be zero address"); + require(payments[requestId].createdAt == 0, "requestId already exists"); + require(timeoutBlocks > 0, "timeoutBlocks must be > 0"); + + payments[requestId] = Payment({ + payer: msg.sender, + payee: payee, + amount: msg.value, + timeoutBlocks: timeoutBlocks, + challengePeriod: challengePeriod, + state: State.Locked, + requestId: requestId, + createdAt: block.number + }); + + emit PaymentCreated(requestId, msg.sender, payee, msg.value); + emit PaymentLocked(requestId); + return true; + } + + /** + * @notice Payer confirms work is done → release funds to payee + * @dev Can only be called by the original payer. Only in Locked state. + */ + function confirmPayment(string calldata requestId) external returns (bool) { + Payment storage p = payments[requestId]; + require(p.payer == msg.sender, "Only payer can confirm"); + require(p.state == State.Locked, "Payment not in Locked state"); + require(block.number < p.createdAt + p.timeoutBlocks, "Payment has expired"); + + p.state = State.Released; + + (bool success, ) = p.payee.call{value: p.amount}(""); + require(success, "Transfer to payee failed"); + + emit PaymentConfirmed(requestId, msg.sender); + emit PaymentReleased(requestId, p.payee, p.amount); + return true; + } + + /** + * @notice Payer requests refund after timeout + challenge period + * @dev After timeout expires AND challenge period passes, payer can reclaim. + */ + function requestRefund(string calldata requestId) external returns (bool) { + Payment storage p = payments[requestId]; + require(p.payer == msg.sender, "Only payer can request refund"); + require(p.state == State.Locked, "Payment not in Locked state"); + require( + block.number >= p.createdAt + p.timeoutBlocks + p.challengePeriod, + "Challenge period not over" + ); + + p.state = State.Refunded; + + (bool success, ) = p.payer.call{value: p.amount}(""); + require(success, "Refund transfer failed"); + + emit PaymentRefunded(requestId, p.payer, p.amount); + return true; + } + + /** + * @notice Cancel a payment before timeout (mutual agreement) + */ + function cancelPayment(string calldata requestId) external returns (bool) { + Payment storage p = payments[requestId]; + require(p.payer == msg.sender, "Only payer can cancel"); + require(p.state == State.Locked, "Payment not in Locked state"); + + uint256 amount = p.amount; + p.state = State.Cancelled; + p.amount = 0; + + (bool success, ) = p.payer.call{value: amount}(""); + require(success, "Cancel refund failed"); + + return true; + } + + /** + * @notice Get payment details + */ + function getPayment(string calldata requestId) external view returns (Payment memory) { + return payments[requestId]; + } + + /** + * @notice Check if a payment is in a given state + */ + function isState(string calldata requestId, State expected) external view returns (bool) { + return payments[requestId].state == expected; + } + + /** + * @notice Check if a payment has expired (timeout passed but not yet in refundable window) + */ + function isExpired(string calldata requestId) external view returns (bool) { + Payment storage p = payments[requestId]; + if (p.createdAt == 0) return false; + return block.number >= p.createdAt + p.timeoutBlocks && p.state == State.Locked; + } +} diff --git a/src/payment_protocol.py b/src/payment_protocol.py new file mode 100644 index 0000000..7063c4e --- /dev/null +++ b/src/payment_protocol.py @@ -0,0 +1,571 @@ +""" +Agent-to-Agent Payment Protocol - Python Client + +Implements the payment protocol for agent-to-agent settlement: +- Payment request format (RFC-style) +- Escrow smart contract interaction via Web3.py +- Confirmation flow with timeout/refund +- Async/concurrent payment management + +Usage: + client = PaymentClient(wallet_private_key, escrow_address, rpc_url) + request_id = await client.create_payment(payee_address, amount_wei, timeout_blocks=100) + await client.confirm_payment(request_id) # after work is done +""" + +import asyncio +import hashlib +import json +import time +import uuid +from dataclasses import dataclass, field, asdict +from enum import Enum +from typing import Optional, Dict, List, Callable +from decimal import Decimal + +try: + from web3 import Web3, AsyncWeb3 + from eth_account import Account + HAS_WEB3 = True +except ImportError: + HAS_WEB3 = False + AsyncWeb3 = None + Account = None + + +# ─── Payment Request Format ───────────────────────────────────────────────── + +@dataclass +class PaymentRequest: + """RFC-style payment request message""" + version: str = "1.0" + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + payer: str = "" # Ethereum address (checksummed) + payee: str = "" # Ethereum address (checksummed) + amount_wei: int = 0 # Amount in wei + amount_usd: Optional[Decimal] = None # Optional USD equivalent + currency: str = "ETH" # ETH, USDT, USDC, etc. + chain_id: int = 1 # Ethereum chain ID + timeout_blocks: int = 100 # Blocks until payment expires + challenge_period_blocks: int = 10 # Blocks payer waits before reclaim + description: str = "" # Human-readable description + metadata: Dict = field(default_factory=dict) # Arbitrary extra data + created_at: float = field(default_factory=time.time) + status: str = "pending" # pending, locked, confirmed, released, refunded, cancelled + + def to_json(self) -> str: + """Serialize to JSON for signing/transmission""" + d = asdict(self) + # Convert Decimal to string for JSON + if self.amount_usd is not None: + d['amount_usd'] = str(self.amount_usd) + return json.dumps(d, sort_keys=True, separators=(',', ':')) + + def to_dict(self) -> dict: + d = asdict(self) + if self.amount_usd is not None: + d['amount_usd'] = str(self.amount_usd) + return d + + @classmethod + def from_dict(cls, d: dict) -> 'PaymentRequest': + d = dict(d) + if d.get('amount_usd'): + d['amount_usd'] = Decimal(d['amount_usd']) + return cls(**d) + + def content_hash(self) -> str: + """Calculate content-based hash for integrity check""" + h = hashlib.sha256() + h.update(self.to_json().encode('utf-8')) + return "0x" + h.hexdigest() + + +# ─── Payment Protocol States ──────────────────────────────────────────────── + +class PaymentState(Enum): + PENDING = "pending" + LOCKED = "locked" # Funds in escrow + CONFIRMED = "confirmed" # Payer approved + RELEASED = "released" # Payee received funds + REFUNDED = "refunded" # Payer reclaimed funds + CANCELLED = "cancelled" + EXPIRED = "expired" # Timed out, awaiting refund window + + +# ─── Escrow ABI ───────────────────────────────────────────────────────────── + +ESCROW_ABI = [ + { + "inputs": [], + "name": "chainId", + "outputs": [{"type": "uint256"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + {"name": "requestId", "type": "string"}, + {"name": "payee", "type": "address"}, + {"name": "timeoutBlocks", "type": "uint256"}, + {"name": "challengePeriod", "type": "uint256"} + ], + "name": "createPayment", + "outputs": [{"type": "bool"}], + "stateMutability": "payable", + "type": "function" + }, + { + "inputs": [{"name": "requestId", "type": "string"}], + "name": "confirmPayment", + "outputs": [{"type": "bool"}], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [{"name": "requestId", "type": "string"}], + "name": "requestRefund", + "outputs": [{"type": "bool"}], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [{"name": "requestId", "type": "string"}], + "name": "cancelPayment", + "outputs": [{"type": "bool"}], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [{"name": "requestId", "type": "string"}], + "name": "getPayment", + "outputs": [ + { + "components": [ + {"name": "payer", "type": "address"}, + {"name": "payee", "type": "address"}, + {"name": "amount", "type": "uint256"}, + {"name": "timeoutBlocks", "type": "uint256"}, + {"name": "challengePeriod", "type": "uint256"}, + {"name": "state", "type": "uint8"}, + {"name": "requestId", "type": "string"}, + {"name": "createdAt", "type": "uint256"} + ], + "type": "tuple", + "name": "", + "internalType": "struct AgentEscrow.Payment" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"name": "requestId", "type": "string"}], + "name": "isState", + "outputs": [{"type": "bool"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"name": "requestId", "type": "string"}], + "name": "isExpired", + "outputs": [{"type": "bool"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"name": "agent", "type": "address"}], + "name": "registerAgent", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "anonymous": False, + "inputs": [ + {"name": "requestId", "type": "string", "indexed": True}, + {"name": "payer", "type": "address", "indexed": True}, + {"name": "payee", "type": "address", "indexed": True}, + {"name": "amount", "type": "uint256"} + ], + "name": "PaymentCreated", + "type": "event" + }, + { + "anonymous": False, + "inputs": [ + {"name": "requestId", "type": "string", "indexed": True}, + {"name": "payee", "type": "address", "indexed": True}, + {"name": "amount", "type": "uint256"} + ], + "name": "PaymentReleased", + "type": "event" + }, + { + "anonymous": False, + "inputs": [ + {"name": "requestId", "type": "string", "indexed": True}, + {"name": "payer", "type": "address", "indexed": True}, + {"name": "amount", "type": "uint256"} + ], + "name": "PaymentRefunded", + "type": "event" + } +] + + +# ─── Payment Client ───────────────────────────────────────────────────────── + +class PaymentClient: + """ + Main client for the agent-to-agent payment protocol. + + Handles: + - Wallet management + - Escrow contract interaction + - Payment state tracking + - Event monitoring + - Timeout/refund flows + """ + + STATE_MAP = {0: "Created", 1: "Locked", 2: "Confirmed", 3: "Released", 4: "Refunded", 5: "Cancelled"} + + def __init__( + self, + private_key: str, + escrow_address: str, + rpc_url: str, + chain_id: int = 1, + confirmations: int = 2, + gas_buffer_wei: int = 50000 + ): + if not HAS_WEB3: + raise ImportError("web3.py is required: pip install web3 eth-account") + + self.account = Account.from_key(private_key) + self.wallet_address = self.account.address + self.escrow_address = Web3.to_checksum_address(escrow_address) + self.chain_id = chain_id + self.confirmations = confirmations + self.gas_buffer_wei = gas_buffer_wei + + self.w3 = Web3(Web3.HTTPProvider(rpc_url)) + self.contract = self.w3.eth.contract( + address=self.escrow_address, + abi=ESCROW_ABI + ) + + # Track pending payments locally + self.pending_payments: Dict[str, PaymentRequest] = {} + self._nonce_cache: Dict[str, int] = {} + + # ─── Wallet Operations ───────────────────────────────────────────────── + + def get_nonce(self, force_refresh: bool = False) -> int: + """Get next nonce for wallet, with caching for concurrent txns""" + if force_refresh or self.wallet_address not in self._nonce_cache: + self._nonce_cache[self.wallet_address] = self.w3.eth.get_transaction_count(self.wallet_address) + else: + self._nonce_cache[self.wallet_address] += 1 + return self._nonce_cache[self.wallet_address] + + def get_gas_price(self) -> int: + return self.w3.eth.gas_price + + def sign_and_send(self, tx: dict) -> str: + """Sign transaction with wallet and send""" + nonce = tx.get('nonce', self.get_nonce()) + tx['nonce'] = nonce + tx['gas'] = int(tx.get('gas', 300000) * 1.2) + tx['gasPrice'] = tx.get('gasPrice', self.get_gas_price()) + tx['chainId'] = self.chain_id + + signed = self.account.sign_transaction(tx) + tx_hash = self.w3.eth.send_raw_transaction(signed.raw_transaction) + return tx_hash.hex() + + def wait_for_confirmations(self, tx_hash: str, confirmations: int = None) -> dict: + """Wait for transaction to be confirmed""" + confirmations = confirmations or self.confirmations + receipt = self.w3.eth.wait_for_transaction_receipt(tx_hash) + if receipt.status == 0: + raise RuntimeError(f"Transaction {tx_hash} failed") + return receipt + + # ─── Payment Operations ──────────────────────────────────────────────── + + def create_payment( + self, + payee: str, + amount_wei: int, + timeout_blocks: int = 100, + challenge_period_blocks: int = 10, + request_id: str = None, + description: str = "", + metadata: dict = None + ) -> PaymentRequest: + """ + Create a payment request and lock funds in escrow. + + Steps: + 1. Build PaymentRequest object + 2. Build createPayment() transaction + 3. Sign and send with value=amount + 4. Wait for confirmation + 5. Store in local tracking + """ + request_id = request_id or str(uuid.uuid4()) + payee_checksum = Web3.to_checksum_address(payee) + + # Build on-chain transaction + tx = self.contract.functions.createPayment( + request_id, + payee_checksum, + timeout_blocks, + challenge_period_blocks + ).build_transaction({ + 'from': self.wallet_address, + 'value': amount_wei + }) + + # Sign and send + tx_hash = self.sign_and_send(tx) + receipt = self.wait_for_confirmations(tx_hash) + + # Build payment request object + payment_req = PaymentRequest( + request_id=request_id, + payer=self.wallet_address, + payee=payee_checksum, + amount_wei=amount_wei, + timeout_blocks=timeout_blocks, + challenge_period_blocks=challenge_period_blocks, + description=description, + metadata=metadata or {}, + status="locked" + ) + self.pending_payments[request_id] = payment_req + return payment_req + + def confirm_payment(self, request_id: str) -> bool: + """ + Confirm a payment - releases funds from escrow to payee. + Called by payer AFTER work is verified/done. + """ + tx = self.contract.functions.confirmPayment(request_id).build_transaction({ + 'from': self.wallet_address + }) + tx_hash = self.sign_and_send(tx) + self.wait_for_confirmations(tx_hash) + + if request_id in self.pending_payments: + self.pending_payments[request_id].status = "confirmed" + return True + + def request_refund(self, request_id: str) -> bool: + """ + Request refund after timeout + challenge period has passed. + Can only be called by original payer. + """ + # Check if refund is available + payment_info = self.contract.functions.getPayment(request_id).call() + state_num = payment_info[5] + created_at = payment_info[7] + + current_block = self.w3.eth.block_number + timeout_blocks = payment_info[3] + challenge_period = payment_info[4] + + if current_block < created_at + timeout_blocks + challenge_period: + raise RuntimeError( + f"Challenge period not over. " + f"Available at block {created_at + timeout_blocks + challenge_period}, " + f"current: {current_block}" + ) + + tx = self.contract.functions.requestRefund(request_id).build_transaction({ + 'from': self.wallet_address + }) + tx_hash = self.sign_and_send(tx) + self.wait_for_confirmations(tx_hash) + + if request_id in self.pending_payments: + self.pending_payments[request_id].status = "refunded" + return True + + def cancel_payment(self, request_id: str) -> bool: + """ + Cancel a payment by mutual agreement. + Can only be called by payer when state is Locked. + """ + tx = self.contract.functions.cancelPayment(request_id).build_transaction({ + 'from': self.wallet_address + }) + tx_hash = self.sign_and_send(tx) + self.wait_for_confirmations(tx_hash) + + if request_id in self.pending_payments: + self.pending_payments[request_id].status = "cancelled" + return True + + # ─── State Queries ───────────────────────────────────────────────────── + + def get_payment_state(self, request_id: str) -> str: + """Query on-chain payment state""" + try: + payment_info = self.contract.functions.getPayment(request_id).call() + return self.STATE_MAP.get(payment_info[5], f"Unknown({payment_info[5]})") + except Exception: + return "NotFound" + + def is_expired(self, request_id: str) -> bool: + """Check if payment has passed its timeout block""" + return self.contract.functions.isExpired(request_id).call() + + def get_payment_details(self, request_id: str) -> dict: + """Get full payment details from escrow""" + info = self.contract.functions.getPayment(request_id).call() + return { + "payer": info[0], + "payee": info[1], + "amount_wei": info[2], + "timeout_blocks": info[3], + "challenge_period": info[4], + "state": self.STATE_MAP.get(info[5], str(info[5])), + "request_id": info[6], + "created_at_block": info[7] + } + + def get_balance(self) -> int: + """Get wallet ETH balance""" + return self.w3.eth.get_balance(self.wallet_address) + + def get_escrow_balance(self, request_id: str) -> int: + """Get amount locked in a specific escrow""" + info = self.get_payment_details(request_id) + return info["amount_wei"] + + # ─── Event Monitoring ────────────────────────────────────────────────── + + def watch_payment(self, request_id: str, callback: Callable[[dict], None], poll_interval: int = 5): + """ + Watch a payment and call callback on state changes. + Example: client.watch_payment("req-123", lambda e: print(f"Payment {e['event']}!")) + """ + last_state = self.get_payment_state(request_id) + while True: + current_state = self.get_payment_state(request_id) + if current_state != last_state: + last_state = current_state + callback({"request_id": request_id, "event": current_state}) + time.sleep(poll_interval) + + +# ─── Async Version ────────────────────────────────────────────────────────── + +class AsyncPaymentClient(PaymentClient): + """Async version using AsyncWeb3""" + + async def create_payment_async(self, payee: str, amount_wei: int, **kwargs) -> PaymentRequest: + """Async version of create_payment""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.create_payment, payee, amount_wei, kwargs) + + async def confirm_payment_async(self, request_id: str) -> bool: + """Async version of confirm_payment""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.confirm_payment, request_id) + + async def wait_for_confirmations_async(self, tx_hash: str) -> dict: + """Async wait for transaction confirmation""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, lambda: self.wait_for_confirmations(tx_hash) + ) + + +# ─── Convenience Functions ───────────────────────────────────────────────── + +def format_wei(wei: int, currency: str = "ETH") -> str: + """Format wei as human-readable currency amount""" + eth = Decimal(wei) / Decimal(10**18) + if currency == "ETH": + return f"{eth:.6f} ETH" + elif currency in ("USDT", "USDC"): + return f"{eth:.2f} {currency}" + return f"{eth:.6f}" + + +def parse_wei(amount: str) -> int: + """Parse human-readable amount to wei. e.g. "0.5 ETH" → wei""" + parts = amount.strip().split() + if len(parts) == 2: + num_str, currency = parts + else: + num_str = parts[0] + currency = "ETH" + + multiplier = { + "ETH": 10**18, + "wei": 1, + "KETH": 10**21, + }.get(currency.upper(), 10**18) + + return int(Decimal(num_str) * Decimal(multiplier)) + + +# ─── CLI Usage ───────────────────────────────────────────────────────────── + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Agent-to-Agent Payment Client") + parser.add_argument("--private-key", required=True, help="Wallet private key") + parser.add_argument("--escrow", required=True, help="Escrow contract address") + parser.add_argument("--rpc", required=True, help="RPC URL") + parser.add_argument("--chain-id", type=int, default=1, help="Chain ID") + parser.add_argument("--action", required=True, choices=["create", "confirm", "refund", "cancel", "status"]) + parser.add_argument("--request-id", help="Payment request ID") + parser.add_argument("--payee", help="Payee address (for create)") + parser.add_argument("--amount", help="Amount in ETH (for create), e.g. '0.1 ETH'") + parser.add_argument("--timeout", type=int, default=100, help="Timeout in blocks") + parser.add_argument("--challenge", type=int, default=10, help="Challenge period in blocks") + + args = parser.parse_args() + + client = PaymentClient(args.private_key, args.escrow, args.rpc, args.chain_id) + + if args.action == "create": + amount_wei = parse_wei(args.amount) + payee = args.payee + req = client.create_payment( + payee, amount_wei, + timeout_blocks=args.timeout, + challenge_period_blocks=args.challenge + ) + print(f"Created payment: {req.request_id}") + print(f"Amount: {format_wei(req.amount_wei)}") + print(f"Payer: {req.payer}") + print(f"Payee: {req.payee}") + + elif args.action == "confirm": + client.confirm_payment(args.request_id) + print(f"Payment {args.request_id} confirmed and released") + + elif args.action == "refund": + client.request_refund(args.request_id) + print(f"Payment {args.request_id} refunded") + + elif args.action == "cancel": + client.cancel_payment(args.request_id) + print(f"Payment {args.request_id} cancelled") + + elif args.action == "status": + state = client.get_payment_state(args.request_id) + details = client.get_payment_details(args.request_id) + print(f"State: {state}") + print(f"Details: {json.dumps(details, indent=2)}") + + +if __name__ == "__main__": + main() diff --git a/switchboard/gas_budget.py b/switchboard/gas_budget.py new file mode 100644 index 0000000..64e741d --- /dev/null +++ b/switchboard/gas_budget.py @@ -0,0 +1,233 @@ +""" +Gas budget tracker for agent wallets. + +Tracks cumulative gas spent per wallet over rolling hour and day windows, +enforces configurable limits, and pauses execution when a budget is exhausted. + +Implements issue #5: + https://github.com/kcolbchain/switchboard/issues/5 + +Design goals +------------ +- Monotonic, thread-safe accounting — safe from multiple agent worker threads. +- Rolling-window enforcement (not calendar buckets), so a burst at 23:59 does + not reset to zero one minute later. +- Pluggable clock for deterministic tests. +- Pure Python, zero new runtime deps. + +Typical usage:: + + tracker = GasBudgetTracker( + default_limits=GasLimits(per_hour=2_000_000, per_day=20_000_000), + ) + + if not tracker.can_spend(wallet, estimated_gas): + raise BudgetExhausted(tracker.status(wallet)) + + # ... send tx ... + tracker.record(wallet, gas_used=receipt.gasUsed) +""" + +from __future__ import annotations + +import threading +import time +from collections import defaultdict, deque +from dataclasses import dataclass, field +from typing import Callable, Deque, Dict, Optional + + +SECONDS_PER_HOUR = 3_600 +SECONDS_PER_DAY = 86_400 + + +class BudgetExhausted(RuntimeError): + """Raised when a wallet would exceed its configured gas budget.""" + + +@dataclass(frozen=True) +class GasLimits: + """Per-wallet gas ceilings. ``None`` disables the corresponding window.""" + + per_hour: Optional[int] = None + per_day: Optional[int] = None + + +@dataclass +class BudgetStatus: + """Snapshot of a wallet's current spend vs. its limits.""" + + wallet: str + limits: GasLimits + spent_last_hour: int + spent_last_day: int + paused: bool + + @property + def remaining_hour(self) -> Optional[int]: + if self.limits.per_hour is None: + return None + return max(0, self.limits.per_hour - self.spent_last_hour) + + @property + def remaining_day(self) -> Optional[int]: + if self.limits.per_day is None: + return None + return max(0, self.limits.per_day - self.spent_last_day) + + +@dataclass +class _WalletLedger: + """Internal per-wallet state. Protected by the tracker lock.""" + + # (timestamp_seconds, gas_used) entries, oldest first. + events: Deque = field(default_factory=deque) + sum_hour: int = 0 + sum_day: int = 0 + paused: bool = False + + +class GasBudgetTracker: + """Tracks cumulative gas per wallet and enforces rolling-window limits. + + Parameters + ---------- + default_limits: + Applied to any wallet that does not have explicit limits set via + :meth:`set_limits`. + clock: + Injectable seconds-resolution clock. Defaults to :func:`time.time`. + Tests should pass a controllable clock to avoid real sleeps. + """ + + def __init__( + self, + default_limits: GasLimits = GasLimits(), + clock: Callable[[], float] = time.time, + ): + self._default_limits = default_limits + self._clock = clock + self._lock = threading.Lock() + self._ledgers: Dict[str, _WalletLedger] = defaultdict(_WalletLedger) + self._limits: Dict[str, GasLimits] = {} + + # ---- configuration ------------------------------------------------- + + def set_limits(self, wallet: str, limits: GasLimits) -> None: + """Override the default limits for ``wallet``.""" + with self._lock: + self._limits[wallet] = limits + + def limits_for(self, wallet: str) -> GasLimits: + return self._limits.get(wallet, self._default_limits) + + # ---- enforcement --------------------------------------------------- + + def can_spend(self, wallet: str, estimated_gas: int) -> bool: + """Return ``True`` if ``estimated_gas`` fits within every active window.""" + if estimated_gas < 0: + raise ValueError("estimated_gas must be non-negative") + + with self._lock: + ledger = self._ledgers[wallet] + self._evict_locked(ledger) + limits = self.limits_for(wallet) + + if ledger.paused: + return False + if limits.per_hour is not None and ledger.sum_hour + estimated_gas > limits.per_hour: + return False + if limits.per_day is not None and ledger.sum_day + estimated_gas > limits.per_day: + return False + return True + + def check(self, wallet: str, estimated_gas: int) -> None: + """Raise :class:`BudgetExhausted` if ``estimated_gas`` cannot be spent.""" + if not self.can_spend(wallet, estimated_gas): + raise BudgetExhausted(self.status(wallet)) + + def record(self, wallet: str, gas_used: int) -> BudgetStatus: + """Record a post-confirmation gas spend and return the new status. + + Auto-pauses the wallet if a limit is crossed after this record. + """ + if gas_used < 0: + raise ValueError("gas_used must be non-negative") + + with self._lock: + ledger = self._ledgers[wallet] + self._evict_locked(ledger) + + now = self._clock() + ledger.events.append((now, gas_used)) + ledger.sum_hour += gas_used + ledger.sum_day += gas_used + + limits = self.limits_for(wallet) + if ( + limits.per_hour is not None and ledger.sum_hour >= limits.per_hour + ) or ( + limits.per_day is not None and ledger.sum_day >= limits.per_day + ): + ledger.paused = True + + return self._status_locked(wallet, ledger, limits) + + # ---- introspection ------------------------------------------------- + + def status(self, wallet: str) -> BudgetStatus: + with self._lock: + ledger = self._ledgers[wallet] + self._evict_locked(ledger) + return self._status_locked(wallet, ledger, self.limits_for(wallet)) + + def resume(self, wallet: str) -> None: + """Manually unpause a wallet. The operator is responsible for ensuring + the underlying budget has freed up — this does not reset counters.""" + with self._lock: + self._ledgers[wallet].paused = False + + def reset(self, wallet: str) -> None: + """Clear all recorded spend for ``wallet`` (e.g. after a new funding round).""" + with self._lock: + self._ledgers[wallet] = _WalletLedger() + + # ---- internals ----------------------------------------------------- + + def _evict_locked(self, ledger: _WalletLedger) -> None: + """Drop events that have aged out of both windows and refresh sums.""" + now = self._clock() + day_cutoff = now - SECONDS_PER_DAY + hour_cutoff = now - SECONDS_PER_HOUR + + # Evict from the daily window (which also removes from hourly). + while ledger.events and ledger.events[0][0] <= day_cutoff: + ts, gas = ledger.events.popleft() + ledger.sum_day -= gas + if ts > hour_cutoff: + # Shouldn't happen — hour window is a subset of day — but keep + # sums consistent defensively. + ledger.sum_hour -= gas + + # Rebuild sum_hour from events (cheap: bounded by day window size). + ledger.sum_hour = sum(gas for ts, gas in ledger.events if ts > hour_cutoff) + + # Auto-unpause if limits have freed up again. + if ledger.paused: + limits_ok_hour = True + limits_ok_day = True + # We don't know wallet limits here; caller re-checks before spending. + # We keep paused sticky until explicit resume() or a fresh record() + # re-evaluates. See docstring on resume(). + del limits_ok_hour, limits_ok_day + + def _status_locked( + self, wallet: str, ledger: _WalletLedger, limits: GasLimits + ) -> BudgetStatus: + return BudgetStatus( + wallet=wallet, + limits=limits, + spent_last_hour=ledger.sum_hour, + spent_last_day=ledger.sum_day, + paused=ledger.paused, + ) diff --git a/switchboard/nonce_manager.py b/switchboard/nonce_manager.py new file mode 100644 index 0000000..036d5d9 --- /dev/null +++ b/switchboard/nonce_manager.py @@ -0,0 +1,251 @@ +import threading +from sortedcontainers import SortedSet +from typing import Dict, Any, Optional, Callable, Protocol + +class ChainClient(Protocol): + """ + Protocol for a blockchain client that provides nonce data. + A concrete implementation would interact with a specific blockchain (e.g., Ethereum RPC). + """ + def get_current_onchain_nonce(self, address: str) -> int: + """ + Fetches the current transaction count (nonce) for an address on the blockchain. + This represents the nonce of the next transaction to be sent from the address + that would be considered valid by the chain. + """ + ... + +class WalletState: + """ + Manages the local nonce state for a single wallet address. + """ + def __init__(self, confirmed_nonce: int): + # The highest sequentially confirmed nonce known to the manager. + self.confirmed_nonce: int = confirmed_nonce + + # Stores nonces that have been acquired by the manager but not yet confirmed on-chain. + # SortedSet ensures nonces are kept in order for easy processing and unique storage. + self.pending_nonces: SortedSet[int] = SortedSet() + + # Maps a pending nonce to its associated transaction object. + # This allows re-queuing of transactions if a reorg invalidates their nonces. + self.pending_transactions: Dict[int, Any] = {} + +class NonceManager: + """ + Manages nonces for multiple wallet addresses, tracking pending and confirmed + transactions and providing reorg protection. + + It ensures nonces are always valid and correctly ordered, even when + concurrent transactions are being sent or chain reorganizations occur. + """ + def __init__(self, chain_client: ChainClient, re_queue_callback: Optional[Callable[[Any], None]] = None): + """ + Initializes the NonceManager. + + Args: + chain_client: An object conforming to the ChainClient protocol, + used to interact with the blockchain to get current on-chain nonces. + re_queue_callback: An optional callback function to be invoked when + transactions need to be re-queued due to a reorg. + It should accept a single argument: the original transaction object. + """ + self._chain_client: ChainClient = chain_client + self._wallet_states: Dict[str, WalletState] = {} + self._lock = threading.Lock() # Protects access to _wallet_states for thread safety + self._re_queue_callback = re_queue_callback + + def _get_wallet_state(self, address: str) -> WalletState: + """ + Retrieves or initializes the WalletState for a given address. + This method must be called under the `_lock` to ensure thread safety. + """ + if address not in self._wallet_states: + # For a new wallet, fetch its current on-chain nonce to initialize. + onchain_nonce = self._chain_client.get_current_onchain_nonce(address) + self._wallet_states[address] = WalletState(onchain_nonce) + return self._wallet_states[address] + + def _sync_with_onchain_nonce(self, state: WalletState, address: str): + """ + Internal method to synchronize the local wallet state with the actual on-chain nonce. + This helps in resolving situations where transactions were confirmed externally + or where a reorg was resolved and new transactions got into blocks. + This method must be called under the `_lock`. + """ + onchain_nonce = self._chain_client.get_current_onchain_nonce(address) + + if onchain_nonce > state.confirmed_nonce: + # The on-chain nonce is higher than our locally confirmed nonce. + # This implies transactions have been confirmed that we might not have tracked locally, + # or previous pending nonces have been included in a block. + + # Identify and remove any local pending nonces that are now below the current + # on-chain nonce, as they are effectively confirmed. + nonces_to_remove = SortedSet(n for n in state.pending_nonces if n < onchain_nonce) + for n in nonces_to_remove: + state.pending_nonces.remove(n) + if n in state.pending_transactions: + del state.pending_transactions[n] + + # Update our locally tracked confirmed_nonce to reflect the latest on-chain state. + state.confirmed_nonce = onchain_nonce + + def acquire_nonce(self, address: str, transaction: Optional[Any] = None) -> int: + """ + Acquires the next available nonce for a given wallet address. + The acquired nonce is marked as 'pending' and associated with a transaction. + + Args: + address: The blockchain wallet address for which to acquire a nonce. + transaction: An optional transaction object to associate with this nonce. + This object will be passed to the `re_queue_callback` if a + reorg invalidates this nonce. + + Returns: + The integer value of the acquired nonce. + """ + with self._lock: + state = self._get_wallet_state(address) + + # First, ensure our local state is synchronized with the latest on-chain nonce. + self._sync_with_onchain_nonce(state, address) + + # Determine the next available nonce. + # If there are any pending nonces, the next one is the highest pending + 1. + # Otherwise, it's the current `confirmed_nonce` (which should be the next expected nonce). + next_nonce = state.confirmed_nonce + if state.pending_nonces: + next_nonce = max(state.pending_nonces) + 1 + + # Add the chosen nonce to the set of pending nonces. + state.pending_nonces.add(next_nonce) + if transaction is not None: + state.pending_transactions[next_nonce] = transaction + return next_nonce + + def release_nonce(self, address: str, nonce: int): + """ + Releases a previously acquired nonce, making it available again. + This is typically used if a transaction using this nonce failed locally + before being broadcast or was dropped from the mempool. + This method does NOT update the `confirmed_nonce` as it doesn't imply + any chain confirmation. + + Args: + address: The wallet address. + nonce: The nonce to release. + """ + with self._lock: + state = self._get_wallet_state(address) + if nonce in state.pending_nonces: + state.pending_nonces.remove(nonce) + if nonce in state.pending_transactions: + del state.pending_transactions[nonce] + # Optionally, log a warning if the nonce was not found in pending_nonces. + + def confirm_nonce(self, address: str, nonce: int): + """ + Marks a nonce as successfully confirmed on the blockchain (i.e., the transaction + using it has been mined into a block). + + Args: + address: The wallet address. + nonce: The nonce to confirm. + """ + with self._lock: + state = self._get_wallet_state(address) + + # If the nonce is currently pending, remove it. + if nonce in state.pending_nonces: + state.pending_nonces.remove(nonce) + if nonce in state.pending_transactions: + del state.pending_transactions[nonce] + elif nonce < state.confirmed_nonce: + # If the nonce is already less than the current confirmed_nonce, + # it means it was previously processed (e.g., via _sync_with_onchain_nonce). + return + + # If the confirmed nonce is sequential to our current `confirmed_nonce`, + # we can advance our `confirmed_nonce`. We also check for and confirm + # any subsequent nonces that are now also sequential. + if nonce == state.confirmed_nonce: + state.confirmed_nonce += 1 + while state.confirmed_nonce in state.pending_nonces: + state.pending_nonces.remove(state.confirmed_nonce) + if state.confirmed_nonce in state.pending_transactions: + del state.pending_transactions[state.confirmed_nonce] + state.confirmed_nonce += 1 + # If `nonce > state.confirmed_nonce` and it was not previously pending, + # it implies a gap in confirmations. We do not directly advance `state.confirmed_nonce` + # past such a gap. The `_sync_with_onchain_nonce` method will eventually correct + # `state.confirmed_nonce` if the missing nonces are confirmed on-chain. + + def on_reorg(self, address: str, reverted_to_nonce: int): + """ + Handles a chain reorganization event for a specific wallet. + This method should be called by an external chain monitor component + when a reorg is detected. + + It adjusts the `confirmed_nonce` for the affected wallet if the reorg + depth requires it and invalidates/re-queues any pending transactions + whose nonces are no longer valid due to the reorg. + + Args: + address: The wallet address affected by the reorg. + reverted_to_nonce: The highest nonce that is considered confirmed + and valid at the common ancestor block after the reorg. + All transactions with nonces equal to or greater than + `reverted_to_nonce` are considered potentially invalid. + """ + with self._lock: + state = self._get_wallet_state(address) + + # If the reorg depth implies that our `confirmed_nonce` is no longer valid, + # revert it to the `reverted_to_nonce` supplied by the reorg detector. + if state.confirmed_nonce > reverted_to_nonce: + state.confirmed_nonce = reverted_to_nonce + + reverted_txns = [] + nonces_to_remove = SortedSet() + + # Identify all pending nonces that are equal to or greater than `reverted_to_nonce`. + # These nonces are now invalid and their associated transactions need to be re-queued. + for nonce in state.pending_nonces: + if nonce >= reverted_to_nonce: + nonces_to_remove.add(nonce) + if nonce in state.pending_transactions: + reverted_txns.append(state.pending_transactions[nonce]) + + # Remove identified invalid nonces and their associated transactions from our local state. + for nonce in nonces_to_remove: + state.pending_nonces.remove(nonce) + del state.pending_transactions[nonce] + + # If a `re_queue_callback` was provided, invoke it for all identified reverted transactions. + if self._re_queue_callback and reverted_txns: + for tx in reverted_txns: + self._re_queue_callback(tx) + + def get_pending_nonces(self, address: str) -> SortedSet[int]: + """ + Returns a copy of the set of nonces currently marked as pending for an address. + """ + with self._lock: + return SortedSet(self._get_wallet_state(address).pending_nonces) + + def get_confirmed_nonce(self, address: str) -> int: + """ + Returns the highest sequentially confirmed nonce for an address known to the manager. + """ + with self._lock: + return self._get_wallet_state(address).confirmed_nonce + + def get_total_pending_transactions(self, address: str) -> int: + """ + Returns the count of transactions currently pending (acquired but not confirmed) + for a specific address. + """ + with self._lock: + return len(self._get_wallet_state(address).pending_transactions) + diff --git a/switchboard/x402_middleware.py b/switchboard/x402_middleware.py new file mode 100644 index 0000000..28273ab --- /dev/null +++ b/switchboard/x402_middleware.py @@ -0,0 +1,284 @@ +""" +x402 Server-Side Middleware for Switchboard + +Implements HTTP 402 Payment Required flows for agent-to-agent API monetization. +When a server returns 402, this middleware automatically handles payment via +the switchboard PaymentClient, then retries the original request with a +payment proof header. + +Supports: +- Automatic 402 detection and payment handling +- Budget-aware payment gating (integrates with GasTracker) +- Payment proof via X-Payment-Proof header +- USDC and ETH settlement +- Configurable per-endpoint pricing caps + +Usage: + middleware = X402Middleware( + payment_client=client, + gas_tracker=tracker, + max_payment_usd=Decimal("1.00"), + ) + response = await middleware.request("https://agent.example.com/inference", payload) + +References: + - https://github.com/coinbase/x402 + - EIP-7702 for smart account payments +""" + +import asyncio +import hashlib +import json +import time +from dataclasses import dataclass, field +from decimal import Decimal +from typing import Optional, Dict, Any, Callable +from enum import Enum + +try: + import aiohttp + HAS_AIOHTTP = True +except ImportError: + HAS_AIOHTTP = False + + +class PaymentScheme(Enum): + """Supported x402 payment schemes.""" + EXACT = "exact" # Pay exact amount specified in 402 response + ESCROW = "escrow" # Lock in escrow, release on delivery + STREAMING = "streaming" # Micro-payments per chunk (future: MPP) + + +@dataclass +class PaymentOffer: + """Parsed from the 402 response's X-Payment-Required header.""" + amount_wei: int + currency: str # "ETH", "USDC", etc. + recipient: str # Payee address + chain_id: int + scheme: PaymentScheme = PaymentScheme.EXACT + description: str = "" + endpoint: str = "" + nonce: str = "" + expires_at: Optional[int] = None # Unix timestamp + + @classmethod + def from_header(cls, header_value: str, endpoint: str = "") -> "PaymentOffer": + """Parse X-Payment-Required header JSON.""" + data = json.loads(header_value) + return cls( + amount_wei=int(data["amount"]), + currency=data.get("currency", "ETH"), + recipient=data["recipient"], + chain_id=int(data.get("chainId", 1)), + scheme=PaymentScheme(data.get("scheme", "exact")), + description=data.get("description", ""), + endpoint=endpoint, + nonce=data.get("nonce", ""), + expires_at=data.get("expiresAt"), + ) + + def is_expired(self) -> bool: + if self.expires_at is None: + return False + return time.time() > self.expires_at + + +@dataclass +class PaymentProof: + """Proof that payment was made, sent back to the server.""" + tx_hash: str + chain_id: int + payer: str + amount_wei: int + nonce: str = "" + timestamp: float = field(default_factory=time.time) + + def to_header(self) -> str: + return json.dumps({ + "txHash": self.tx_hash, + "chainId": self.chain_id, + "payer": self.payer, + "amount": self.amount_wei, + "nonce": self.nonce, + "timestamp": int(self.timestamp), + }) + + +@dataclass +class PaymentRecord: + """Log of a completed payment.""" + endpoint: str + offer: PaymentOffer + proof: PaymentProof + response_status: int + paid_at: float = field(default_factory=time.time) + + +class X402Middleware: + """ + HTTP middleware that intercepts 402 responses and pays automatically. + + Integrates with: + - PaymentClient for on-chain settlement + - GasTracker for budget enforcement + """ + + def __init__( + self, + payment_client, + gas_tracker=None, + max_payment_wei: int = 10**16, # 0.01 ETH default cap + allowed_recipients: Optional[set] = None, + auto_pay: bool = True, + on_payment: Optional[Callable[[PaymentRecord], None]] = None, + ): + if not HAS_AIOHTTP: + raise ImportError("aiohttp required: pip install aiohttp") + + self.payment_client = payment_client + self.gas_tracker = gas_tracker + self.max_payment_wei = max_payment_wei + self.allowed_recipients = allowed_recipients + self.auto_pay = auto_pay + self.on_payment = on_payment + + self.payment_history: list[PaymentRecord] = [] + self.total_spent_wei: int = 0 + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + + def _validate_offer(self, offer: PaymentOffer) -> None: + """Check offer against policy before paying.""" + if offer.is_expired(): + raise ValueError(f"Payment offer expired at {offer.expires_at}") + + if offer.amount_wei > self.max_payment_wei: + raise ValueError( + f"Payment {offer.amount_wei} exceeds cap {self.max_payment_wei}" + ) + + if self.allowed_recipients and offer.recipient not in self.allowed_recipients: + raise ValueError(f"Recipient {offer.recipient} not in allowlist") + + if self.gas_tracker: + if not self.gas_tracker.can_send_transaction(offer.amount_wei): + raise ValueError("Payment would exceed gas budget") + + def _pay_onchain(self, offer: PaymentOffer) -> PaymentProof: + """Execute on-chain payment via PaymentClient.""" + if offer.scheme == PaymentScheme.EXACT: + # Direct transfer — build and send a simple value transfer + tx = { + "to": offer.recipient, + "value": offer.amount_wei, + "from": self.payment_client.wallet_address, + } + tx_hash = self.payment_client.sign_and_send(tx) + self.payment_client.wait_for_confirmations(tx_hash) + + return PaymentProof( + tx_hash=tx_hash, + chain_id=offer.chain_id, + payer=self.payment_client.wallet_address, + amount_wei=offer.amount_wei, + nonce=offer.nonce, + ) + + elif offer.scheme == PaymentScheme.ESCROW: + req = self.payment_client.create_payment( + payee=offer.recipient, + amount_wei=offer.amount_wei, + timeout_blocks=50, + description=offer.description, + ) + return PaymentProof( + tx_hash=req.request_id, + chain_id=offer.chain_id, + payer=self.payment_client.wallet_address, + amount_wei=offer.amount_wei, + nonce=offer.nonce, + ) + + else: + raise ValueError(f"Unsupported payment scheme: {offer.scheme}") + + async def request( + self, + url: str, + payload: Any = None, + method: str = "POST", + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> aiohttp.ClientResponse: + """ + Make an HTTP request. If the server returns 402, automatically + pay and retry with payment proof. + """ + session = await self._get_session() + headers = dict(headers or {}) + + # First attempt + if method == "POST": + resp = await session.post(url, json=payload, headers=headers, **kwargs) + else: + resp = await session.request(method, url, headers=headers, **kwargs) + + if resp.status != 402 or not self.auto_pay: + return resp + + # Parse 402 payment offer + payment_header = resp.headers.get("X-Payment-Required") + if not payment_header: + return resp # 402 without payment header — can't auto-pay + + offer = PaymentOffer.from_header(payment_header, endpoint=url) + self._validate_offer(offer) + + # Pay on-chain + proof = self._pay_onchain(offer) + + # Record payment + if self.gas_tracker: + self.gas_tracker.record_gas_usage(offer.amount_wei) + self.total_spent_wei += offer.amount_wei + + # Retry with payment proof + headers["X-Payment-Proof"] = proof.to_header() + if method == "POST": + resp2 = await session.post(url, json=payload, headers=headers, **kwargs) + else: + resp2 = await session.request(method, url, headers=headers, **kwargs) + + record = PaymentRecord( + endpoint=url, + offer=offer, + proof=proof, + response_status=resp2.status, + ) + self.payment_history.append(record) + if self.on_payment: + self.on_payment(record) + + return resp2 + + def get_spend_summary(self) -> dict: + """Return summary of all payments made.""" + by_endpoint: Dict[str, int] = {} + for record in self.payment_history: + by_endpoint[record.endpoint] = ( + by_endpoint.get(record.endpoint, 0) + record.offer.amount_wei + ) + return { + "total_payments": len(self.payment_history), + "total_spent_wei": self.total_spent_wei, + "by_endpoint": by_endpoint, + } diff --git a/switchboard/zap_transport.py b/switchboard/zap_transport.py new file mode 100644 index 0000000..3c80f8b --- /dev/null +++ b/switchboard/zap_transport.py @@ -0,0 +1,245 @@ +"""ZAP wire transport for switchboard payment flows. + +Switchboard's existing transport encodes ``PaymentOffer`` / ``PaymentProof`` +as JSON in HTTP headers. That is fine for HTTP/REST agents but expensive +for high-volume agent-to-agent traffic — every offer is parsed, allocated, +and copied. This module adds a binary alternative: encode an offer or a +proof as a `ZAP `_ message so two agents +sitting on the same Lux network (port 9999) can exchange them without +HTTP, JSON, or per-call allocation. + +The wire layout is a fixed ZAP struct schema declared up front, so any +ZAP-speaking language (Go via ``luxfi/zap`` upstream, Python via +``zap_py``) reads and writes the same bytes. Field offsets and total +struct size are pinned by tests against the canonical +``StructBuilder.build()`` output. + +zap_py is an *optional* dependency. If it isn't installed, +``encode_offer`` / ``decode_offer`` raise ``ZapNotAvailable`` and callers +should fall back to the existing JSON path. Tests are skipped via +``pytest.importorskip`` so the suite stays green either way. + +Install (until ``luxfi-zap`` is on PyPI):: + + pip install 'luxfi-zap @ git+https://github.com/luxfi/zap@main#subdirectory=python' + +References +---------- +- ``switchboard.x402_middleware.PaymentOffer`` / ``PaymentProof`` +- luxfi/zap (Go reference) + ``python/zap_py`` (parity-tested Python port) +""" + +from __future__ import annotations + +from typing import Optional, Tuple + +from .x402_middleware import PaymentOffer, PaymentProof, PaymentScheme + +try: + from zap_py import ( + ADDRESS_SIZE, + Address, + Builder, + HASH_SIZE, + Hash, + StructBuilder, + Type, + address_from_hex, + parse, + ) + + HAS_ZAP_PY = True +except ImportError: # pragma: no cover — exercised by environment, not tests + HAS_ZAP_PY = False + + +__all__ = [ + "HAS_ZAP_PY", + "ZapNotAvailable", + "OFFER_SCHEMA", + "PROOF_SCHEMA", + "encode_offer", + "decode_offer", + "encode_proof", + "decode_proof", +] + + +class ZapNotAvailable(RuntimeError): + """zap_py is not installed; ZAP transport is unavailable.""" + + +# ─── Wire constants ────────────────────────────────────────────────────────── +# +# Both schemas use uint256 amount-as-bytes (32 LE bytes) so we don't truncate +# realistic on-chain values into a uint64 — the JSON path already accepts +# arbitrarily large `int`s and we want byte-for-byte interop with that. +_AMOUNT_BYTES = 32 + +# ``scheme`` is encoded as uint8. Order matches the wire intent, not the Python +# enum's iteration order — pin it here so a Go implementation can mirror it. +_SCHEME_TO_WIRE = { + PaymentScheme.EXACT: 0, + PaymentScheme.ESCROW: 1, + PaymentScheme.STREAMING: 2, +} +_WIRE_TO_SCHEME = {v: k for k, v in _SCHEME_TO_WIRE.items()} + + +def _build_offer_schema(): + if not HAS_ZAP_PY: + return None + return ( + StructBuilder("SwitchboardPaymentOffer") + .uint8("scheme") + .uint64("chain_id") + .uint64("expires_at") # 0 sentinel = "no expiry" + .address("recipient") + .bytes("amount") # 32-byte big-endian uint256 + .text("currency") + .text("description") + .text("endpoint") + .text("nonce") + .build() + ) + + +def _build_proof_schema(): + if not HAS_ZAP_PY: + return None + return ( + StructBuilder("SwitchboardPaymentProof") + .uint64("chain_id") + .uint64("timestamp") + .address("payer") + .hash("tx_hash") + .bytes("amount") # 32-byte big-endian uint256 + .text("nonce") + .build() + ) + + +OFFER_SCHEMA = _build_offer_schema() +PROOF_SCHEMA = _build_proof_schema() + + +def _require_zap() -> None: + if not HAS_ZAP_PY: + raise ZapNotAvailable( + "zap_py is not installed; install luxfi-zap (see switchboard/zap_transport.py)" + ) + + +def _amount_to_bytes(amount: int) -> bytes: + if amount < 0: + raise ValueError("amount must be non-negative") + if amount.bit_length() > _AMOUNT_BYTES * 8: + raise ValueError(f"amount exceeds uint{_AMOUNT_BYTES * 8}") + return amount.to_bytes(_AMOUNT_BYTES, "big") + + +def _amount_from_bytes(data: bytes) -> int: + if len(data) != _AMOUNT_BYTES: + raise ValueError(f"amount field must be {_AMOUNT_BYTES} bytes, got {len(data)}") + return int.from_bytes(data, "big") + + +def _addr_to_bytes(s: str) -> bytes: + """Accept a 0x-prefixed hex address; return raw 20 bytes.""" + return address_from_hex(s).bytes + + +def _addr_to_hex(addr) -> str: + return addr.hex() + + +# ─── PaymentOffer ──────────────────────────────────────────────────────────── + + +def encode_offer(offer: PaymentOffer) -> bytes: + """Serialize a PaymentOffer to a ZAP wire message (zero allocations on read).""" + _require_zap() + f = {fld.name: fld.offset for fld in OFFER_SCHEMA.fields} + + b = Builder() + ob = b.start_object(OFFER_SCHEMA.size) + ob.set_uint8(f["scheme"], _SCHEME_TO_WIRE[offer.scheme]) + ob.set_uint64(f["chain_id"], offer.chain_id) + ob.set_uint64(f["expires_at"], offer.expires_at or 0) + ob.set_address(f["recipient"], _addr_to_bytes(offer.recipient)) + ob.set_bytes(f["amount"], _amount_to_bytes(offer.amount_wei)) + ob.set_text(f["currency"], offer.currency) + ob.set_text(f["description"], offer.description) + ob.set_text(f["endpoint"], offer.endpoint) + ob.set_text(f["nonce"], offer.nonce) + ob.finish_as_root() + return b.finish() + + +def decode_offer(wire: bytes) -> PaymentOffer: + """Parse a ZAP wire message into a PaymentOffer.""" + _require_zap() + f = {fld.name: fld.offset for fld in OFFER_SCHEMA.fields} + + msg = parse(wire) + root = msg.root() + + expires = root.uint64(f["expires_at"]) + return PaymentOffer( + amount_wei=_amount_from_bytes(root.bytes(f["amount"])), + currency=root.text(f["currency"]), + recipient=_addr_to_hex(root.address(f["recipient"])), + chain_id=root.uint64(f["chain_id"]), + scheme=_WIRE_TO_SCHEME[root.uint8(f["scheme"])], + description=root.text(f["description"]), + endpoint=root.text(f["endpoint"]), + nonce=root.text(f["nonce"]), + expires_at=int(expires) if expires else None, + ) + + +# ─── PaymentProof ──────────────────────────────────────────────────────────── + + +def _hash_from_hex(s: str) -> bytes: + if s.startswith("0x") or s.startswith("0X"): + s = s[2:] + raw = bytes.fromhex(s) + if len(raw) != HASH_SIZE: + raise ValueError(f"tx_hash must be {HASH_SIZE} bytes, got {len(raw)}") + return raw + + +def encode_proof(proof: PaymentProof) -> bytes: + """Serialize a PaymentProof to a ZAP wire message.""" + _require_zap() + f = {fld.name: fld.offset for fld in PROOF_SCHEMA.fields} + + b = Builder() + ob = b.start_object(PROOF_SCHEMA.size) + ob.set_uint64(f["chain_id"], proof.chain_id) + ob.set_uint64(f["timestamp"], int(proof.timestamp)) + ob.set_address(f["payer"], _addr_to_bytes(proof.payer)) + ob.set_hash(f["tx_hash"], _hash_from_hex(proof.tx_hash)) + ob.set_bytes(f["amount"], _amount_to_bytes(proof.amount_wei)) + ob.set_text(f["nonce"], proof.nonce) + ob.finish_as_root() + return b.finish() + + +def decode_proof(wire: bytes) -> PaymentProof: + """Parse a ZAP wire message into a PaymentProof.""" + _require_zap() + f = {fld.name: fld.offset for fld in PROOF_SCHEMA.fields} + + msg = parse(wire) + root = msg.root() + + return PaymentProof( + tx_hash=root.hash(f["tx_hash"]).hex(), + chain_id=root.uint64(f["chain_id"]), + payer=_addr_to_hex(root.address(f["payer"])), + amount_wei=_amount_from_bytes(root.bytes(f["amount"])), + nonce=root.text(f["nonce"]), + timestamp=float(root.uint64(f["timestamp"])), + ) diff --git a/tests/test_gas_budget.py b/tests/test_gas_budget.py new file mode 100644 index 0000000..9b12f21 --- /dev/null +++ b/tests/test_gas_budget.py @@ -0,0 +1,224 @@ +"""Tests for switchboard.gas_budget — see issue #5.""" + +from __future__ import annotations + +import threading + +import pytest + +from switchboard.gas_budget import ( + BudgetExhausted, + GasBudgetTracker, + GasLimits, + SECONDS_PER_DAY, + SECONDS_PER_HOUR, +) + + +class FakeClock: + """Deterministic monotonically-controllable clock.""" + + def __init__(self, start: float = 1_700_000_000.0): + self._t = start + + def __call__(self) -> float: + return self._t + + def advance(self, seconds: float) -> None: + self._t += seconds + + +WALLET = "0xAgent" + + +# ---------------------------------------------------------------- basics + + +def test_default_limits_allow_everything(): + t = GasBudgetTracker() + assert t.can_spend(WALLET, 10**12) is True + t.record(WALLET, 10**12) + status = t.status(WALLET) + assert status.paused is False + assert status.remaining_hour is None + assert status.remaining_day is None + + +def test_record_rejects_negative(): + t = GasBudgetTracker() + with pytest.raises(ValueError): + t.record(WALLET, -1) + with pytest.raises(ValueError): + t.can_spend(WALLET, -1) + + +# ---------------------------------------------------------------- hour + + +def test_hourly_limit_blocks_overspend(): + clock = FakeClock() + t = GasBudgetTracker( + default_limits=GasLimits(per_hour=100_000), + clock=clock, + ) + + assert t.can_spend(WALLET, 60_000) + t.record(WALLET, 60_000) + + # 60k of 100k used — 50k more would exceed. + assert t.can_spend(WALLET, 30_000) is True + assert t.can_spend(WALLET, 50_000) is False + + +def test_hourly_window_rolls_forward(): + clock = FakeClock() + t = GasBudgetTracker( + default_limits=GasLimits(per_hour=100_000), + clock=clock, + ) + + t.record(WALLET, 90_000) + assert t.can_spend(WALLET, 20_000) is False + + # Slide past the hour boundary. + clock.advance(SECONDS_PER_HOUR + 1) + + # Spend should now fit — but wallet remains paused until operator resumes + # (prior spend exhausted the limit, pausing it). Resume and retry. + t.resume(WALLET) + assert t.can_spend(WALLET, 90_000) is True + + +# ---------------------------------------------------------------- day + + +def test_daily_limit_independent_of_hourly(): + clock = FakeClock() + t = GasBudgetTracker( + default_limits=GasLimits(per_hour=50_000, per_day=120_000), + clock=clock, + ) + + for _ in range(3): + assert t.can_spend(WALLET, 40_000) + t.record(WALLET, 40_000) + clock.advance(SECONDS_PER_HOUR + 1) + t.resume(WALLET) # re-enable after each hourly pause + + # Day total now 120k == limit exactly; anything more should fail. + assert t.can_spend(WALLET, 1) is False + + +def test_daily_window_rolls_forward(): + clock = FakeClock() + t = GasBudgetTracker( + default_limits=GasLimits(per_day=100_000), + clock=clock, + ) + + t.record(WALLET, 100_000) + assert t.status(WALLET).paused is True + + clock.advance(SECONDS_PER_DAY + 1) + t.resume(WALLET) + assert t.can_spend(WALLET, 100_000) is True + assert t.status(WALLET).spent_last_day == 0 + + +# ---------------------------------------------------------------- pause + + +def test_pause_on_exhaustion_blocks_further_spending(): + t = GasBudgetTracker(default_limits=GasLimits(per_hour=1_000)) + t.record(WALLET, 1_000) + status = t.status(WALLET) + assert status.paused is True + assert t.can_spend(WALLET, 1) is False + + +def test_check_raises_when_exhausted(): + t = GasBudgetTracker(default_limits=GasLimits(per_hour=500)) + t.record(WALLET, 500) + with pytest.raises(BudgetExhausted) as exc: + t.check(WALLET, 1) + assert exc.value.args[0].wallet == WALLET + + +def test_resume_clears_pause_without_resetting_counters(): + clock = FakeClock() + t = GasBudgetTracker(default_limits=GasLimits(per_hour=1_000), clock=clock) + t.record(WALLET, 1_000) + assert t.status(WALLET).paused is True + + t.resume(WALLET) + status = t.status(WALLET) + assert status.paused is False + assert status.spent_last_hour == 1_000 # counters intact + + +# ---------------------------------------------------------------- per-wallet + + +def test_per_wallet_limits_override_default(): + t = GasBudgetTracker(default_limits=GasLimits(per_hour=1_000)) + t.set_limits("0xVIP", GasLimits(per_hour=10_000)) + + t.record("0xVIP", 5_000) + t.record(WALLET, 900) + assert t.status("0xVIP").paused is False + assert t.can_spend(WALLET, 500) is False # default limit binds + + +def test_wallets_are_isolated(): + t = GasBudgetTracker(default_limits=GasLimits(per_hour=1_000)) + t.record("a", 1_000) + assert t.status("a").paused is True + assert t.status("b").paused is False + assert t.can_spend("b", 999) is True + + +# ---------------------------------------------------------------- reset + + +def test_reset_clears_history(): + t = GasBudgetTracker(default_limits=GasLimits(per_hour=100)) + t.record(WALLET, 100) + t.reset(WALLET) + s = t.status(WALLET) + assert s.spent_last_hour == 0 + assert s.paused is False + assert t.can_spend(WALLET, 100) is True + + +# ---------------------------------------------------------------- threads + + +def test_thread_safety_sum_is_exact(): + t = GasBudgetTracker() # no limits; just exercise locking + N = 500 + workers = 8 + + def run(): + for _ in range(N): + t.record(WALLET, 1) + + threads = [threading.Thread(target=run) for _ in range(workers)] + for th in threads: + th.start() + for th in threads: + th.join() + + assert t.status(WALLET).spent_last_hour == N * workers + + +# ---------------------------------------------------------------- status + + +def test_status_reports_remaining(): + t = GasBudgetTracker( + default_limits=GasLimits(per_hour=10_000, per_day=100_000), + ) + t.record(WALLET, 3_000) + s = t.status(WALLET) + assert s.remaining_hour == 7_000 + assert s.remaining_day == 97_000 diff --git a/tests/test_nonce_manager.py b/tests/test_nonce_manager.py new file mode 100644 index 0000000..02a6d39 --- /dev/null +++ b/tests/test_nonce_manager.py @@ -0,0 +1,384 @@ +import unittest +import threading +from typing import Dict, Any, List, Callable +from sortedcontainers import SortedSet + +# Assuming nonce_manager.py is correctly importable from 'switchboard' package +from switchboard.nonce_manager import NonceManager, ChainClient + +# --- Mock ChainClient for testing --- +class MockChainClientImpl: + """ + A mock implementation of the ChainClient protocol for testing purposes. + Allows simulating on-chain nonce changes. + """ + def __init__(self, initial_onchain_nonces: Dict[str, int]): + self._onchain_nonces = initial_onchain_nonces + self._lock = threading.Lock() + + def get_current_onchain_nonce(self, address: str) -> int: + """ + Returns the simulated current transaction count (nonce) for an address. + """ + with self._lock: + return self._onchain_nonces.get(address, 0) + + def set_onchain_nonce(self, address: str, nonce: int): + """ + Simulates an external confirmation or a direct change in the blockchain's + reported nonce for an address. + """ + with self._lock: + self._onchain_nonces[address] = nonce + +class MockTransaction: + """ + A simple mock transaction object to be associated with nonces and re-queued. + """ + def __init__(self, nonce: int, content: str): + self.nonce = nonce + self.content = content + self.re_queued_count = 0 # To track how many times it was re-queued + + def __repr__(self): + return f"MockTransaction(nonce={self.nonce}, content='{self.content}', re_queued_count={self.re_queued_count})" + + def __eq__(self, other): + if not isinstance(other, MockTransaction): + return NotImplemented + return self.nonce == other.nonce and self.content == other.content + + +# --- Unit Tests for NonceManager --- +class TestNonceManager(unittest.TestCase): + def setUp(self): + """ + Set up shared resources for each test case. + Initializes the mock chain client and NonceManager. + """ + self.wallet_address_1 = "0xAgentWallet1" + self.wallet_address_2 = "0xAgentWallet2" + self.initial_nonces = { + self.wallet_address_1: 0, + self.wallet_address_2: 5, # Simulate an agent that already has some txns confirmed + } + self.mock_chain_client = MockChainClientImpl(self.initial_nonces.copy()) + + # List to capture transactions passed to the re_queue_callback + self.re_queued_txns: List[MockTransaction] = [] + + # Define the re-queue callback function + def re_queue_callback(tx: MockTransaction): + tx.re_queued_count += 1 + self.re_queued_txns.append(tx) + + self.nonce_manager = NonceManager(self.mock_chain_client, re_queue_callback) + + def test_initial_state_and_acquire_nonce(self): + """ + Tests the initial state of wallets and basic nonce acquisition. + """ + # Wallet 1: Starts with 0 on-chain nonce + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) + + # Acquire first nonce for Wallet 1 (should be 0) + tx1_0 = MockTransaction(0, "tx_0_w1") + nonce0 = self.nonce_manager.acquire_nonce(self.wallet_address_1, tx1_0) + self.assertEqual(nonce0, 0) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 1) + + # Acquire second nonce for Wallet 1 (should be 1) + tx1_1 = MockTransaction(1, "tx_1_w1") + nonce1 = self.nonce_manager.acquire_nonce(self.wallet_address_1, tx1_1) + self.assertEqual(nonce1, 1) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0, 1])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 2) + + # Wallet 2: Starts with 5 on-chain nonce + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_2), 5) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_2), SortedSet()) + + # Acquire first nonce for Wallet 2 (should be 5) + tx2_5 = MockTransaction(5, "tx_5_w2") + nonce5 = self.nonce_manager.acquire_nonce(self.wallet_address_2, tx2_5) + self.assertEqual(nonce5, 5) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_2), 5) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_2), SortedSet([5])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_2), 1) + + def test_confirm_nonce_sequential(self): + """ + Tests nonce confirmation when transactions are mined in sequential order. + """ + # Acquire some nonces for Wallet 1 + tx_w1_0 = MockTransaction(0, "w1_0") + tx_w1_1 = MockTransaction(1, "w1_1") + tx_w1_2 = MockTransaction(2, "w1_2") + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_0) # nonce 0 + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_1) # nonce 1 + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_2) # nonce 2 + + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0, 1, 2])) + + # Confirm nonce 0 + self.nonce_manager.confirm_nonce(self.wallet_address_1, 0) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 1) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([1, 2])) + + # Confirm nonce 1 + self.nonce_manager.confirm_nonce(self.wallet_address_1, 1) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 2) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([2])) + + # Confirm nonce 2 + self.nonce_manager.confirm_nonce(self.wallet_address_1, 2) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 3) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 0) + + def test_confirm_nonce_out_of_order_or_gap(self): + """ + Tests nonce confirmation when transactions are mined out of sequential order. + """ + # Acquire nonces 0, 1, 2 + tx_w1_0 = MockTransaction(0, "w1_0") + tx_w1_1 = MockTransaction(1, "w1_1") + tx_w1_2 = MockTransaction(2, "w1_2") + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_0) + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_1) + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_2) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0, 1, 2])) + + # Confirm nonce 2 directly (out of order). Confirmed_nonce should NOT advance past 0 + # because nonces 0 and 1 are still pending. + self.nonce_manager.confirm_nonce(self.wallet_address_1, 2) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0, 1])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 2) + + # Confirm nonce 0. This will advance confirmed_nonce to 1. + self.nonce_manager.confirm_nonce(self.wallet_address_1, 0) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 1) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([1])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 1) + + # Confirm nonce 1. This will advance confirmed_nonce to 3 (because nonce 2 was already handled). + self.nonce_manager.confirm_nonce(self.wallet_address_1, 1) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 3) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 0) + + def test_release_nonce(self): + """ + Tests releasing a pending nonce, e.g., if a transaction is dropped. + """ + tx_w1_0 = MockTransaction(0, "w1_0") + tx_w1_1 = MockTransaction(1, "w1_1") + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_0) + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_1) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0, 1])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 2) + + # Release nonce 1 + self.nonce_manager.release_nonce(self.wallet_address_1, 1) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 1) + + # Confirm nonce 0 + self.nonce_manager.confirm_nonce(self.wallet_address_1, 0) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 1) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) + + # Attempt to release a nonce that was never pending or already confirmed, should have no effect + self.nonce_manager.release_nonce(self.wallet_address_1, 5) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) + + def test_sync_with_onchain_nonce_external_confirmation(self): + """ + Tests synchronization with the on-chain nonce when external transactions + or unknown confirmations have advanced the chain state. + """ + # Initial state: confirmed_nonce = 0, pending = {} + # Acquire nonces 0, 1, 2 locally + tx_w1_0 = MockTransaction(0, "w1_0") + tx_w1_1 = MockTransaction(1, "w1_1") + tx_w1_2 = MockTransaction(2, "w1_2") + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_0) + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_1) + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_2) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0, 1, 2])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 3) + + # Simulate external confirmation of nonces 0 and 1, so the chain's next nonce is 2. + self.mock_chain_client.set_onchain_nonce(self.wallet_address_1, 2) + + # When `acquire_nonce` is called again, it will trigger `_sync_with_onchain_nonce`. + tx_w1_new = MockTransaction(2, "w1_new") # New transaction trying to acquire a nonce + acquired_nonce = self.nonce_manager.acquire_nonce(self.wallet_address_1, tx_w1_new) + + # Expect the manager to have synced: confirmed_nonce moves to 2. + # Pending nonces < 2 (i.e., 0 and 1) are removed. + # The new transaction acquires nonce 2. + self.assertEqual(acquired_nonce, 2) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 2) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([2])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 1) + + # Verify that the new transaction has overwritten the old one for nonce 2 in pending_transactions + state = self.nonce_manager._get_wallet_state(self.wallet_address_1) + self.assertIn(2, state.pending_transactions) + self.assertEqual(state.pending_transactions[2].content, "w1_new") + + def test_on_reorg(self): + """ + Tests the `on_reorg` mechanism, ensuring nonces are reverted and transactions re-queued. + """ + # Acquire nonces 0, 1, 2, 3 + txs = {} + for i in range(4): + tx = MockTransaction(i, f"w1_{i}") + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx) + txs[i] = tx + + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([0, 1, 2, 3])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 4) + + # Confirm nonce 0 and 1 + self.nonce_manager.confirm_nonce(self.wallet_address_1, 0) + self.nonce_manager.confirm_nonce(self.wallet_address_1, 1) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 2) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([2, 3])) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 2) + self.assertEqual(len(self.re_queued_txns), 0) + + # Simulate a reorg where the chain reverts back to nonce 1 as the common ancestor's nonce. + # This means transactions with nonce 1 and higher are potentially invalid. + # Our local `confirmed_nonce` is 2, so it will be reverted to 1. + self.nonce_manager.on_reorg(self.wallet_address_1, 1) + + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 1) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) # All pending (2, 3) are now gone + self.assertEqual(self.nonce_manager.get_total_pending_transactions(self.wallet_address_1), 0) + + # Check that the affected transactions were re-queued + self.assertEqual(len(self.re_queued_txns), 2) + re_queued_nonces = {tx.nonce for tx in self.re_queued_txns} + self.assertIn(2, re_queued_nonces) + self.assertIn(3, re_queued_nonces) + self.assertEqual(txs[2].re_queued_count, 1) + self.assertEqual(txs[3].re_queued_count, 1) + + # Acquire a new nonce after the reorg; it should now correctly pick up from the reverted state. + new_tx = MockTransaction(1, "w1_new_after_reorg") + new_nonce = self.nonce_manager.acquire_nonce(self.wallet_address_1, new_tx) + self.assertEqual(new_nonce, 1) # Should now acquire nonce 1 again + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet([1])) + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 1) + + def test_on_reorg_deeper_than_pending(self): + """ + Tests reorg handling when the reorg depth affects previously confirmed nonces. + """ + # Acquire nonces 0, 1, 2 + txs = {} + for i in range(3): + tx = MockTransaction(i, f"w1_{i}") + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx) + txs[i] = tx + + # Confirm nonces 0 and 1 + self.nonce_manager.confirm_nonce(self.wallet_address_1, 0) + self.nonce_manager.confirm_nonce(self.wallet_address_1, 1) + # Current state: confirmed_nonce = 2, pending = {2} + + # Simulate a deep reorg to common ancestor nonce 0. + # This means confirmed_nonce (2) should revert to 0. + # Pending nonce 2 should also be invalidated. + self.nonce_manager.on_reorg(self.wallet_address_1, 0) + + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 0) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) + self.assertEqual(len(self.re_queued_txns), 1) # Only txs[2] was pending and >= 0 + self.assertEqual(self.re_queued_txns[0].nonce, 2) + + def test_on_reorg_no_effect_on_confirmed(self): + """ + Tests reorg handling where the common ancestor nonce is equal to the current + confirmed_nonce, only affecting pending transactions. + """ + # Acquire nonces 0, 1 + txs = {} + for i in range(2): + tx = MockTransaction(i, f"w1_{i}") + self.nonce_manager.acquire_nonce(self.wallet_address_1, tx) + txs[i] = tx + + # Confirm nonce 0 + self.nonce_manager.confirm_nonce(self.wallet_address_1, 0) + # Current state: confirmed_nonce = 1, pending = {1} + + # Simulate reorg to common ancestor nonce 1. + # `confirmed_nonce` (1) matches `reverted_to_nonce` (1), so confirmed_nonce doesn't change. + # Only pending nonces >= 1 (i.e., pending nonce 1) are affected. + self.nonce_manager.on_reorg(self.wallet_address_1, 1) + + self.assertEqual(self.nonce_manager.get_confirmed_nonce(self.wallet_address_1), 1) + self.assertEqual(self.nonce_manager.get_pending_nonces(self.wallet_address_1), SortedSet()) # Pending 1 removed + self.assertEqual(len(self.re_queued_txns), 1) + self.assertEqual(self.re_queued_txns[0].nonce, 1) + + def test_concurrent_access(self): + """ + Tests thread safety of `acquire_nonce` under concurrent access. + """ + num_threads = 5 + num_tx_per_thread = 10 + wallet = self.wallet_address_1 + + # Simulate initial external confirmation to set a starting point. + # `acquire_nonce` will sync with this, setting confirmed_nonce to 10. + self.mock_chain_client.set_onchain_nonce(wallet, 10) + + def agent_task(): + for _ in range(num_tx_per_thread): + tx = MockTransaction(0, "concurrent_tx") # Nonce will be assigned by manager + self.nonce_manager.acquire_nonce(wallet, tx) + + threads = [] + for _ in range(num_threads): + t = threading.Thread(target=agent_task) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # After all `acquire_nonce` calls, check the total number of pending nonces + pending = self.nonce_manager.get_pending_nonces(wallet) + self.assertEqual(len(pending), num_threads * num_tx_per_thread) + + # All acquired nonces should be unique and sequential, starting from the synced confirmed_nonce (10). + expected_start_nonce = 10 + expected_nonces = SortedSet(range(expected_start_nonce, expected_start_nonce + num_threads * num_tx_per_thread)) + self.assertEqual(pending, expected_nonces) + + # Simulate all transactions being confirmed in order to finalize state + for i in range(expected_start_nonce, expected_start_nonce + num_threads * num_tx_per_thread): + self.nonce_manager.confirm_nonce(wallet, i) + + self.assertEqual(self.nonce_manager.get_confirmed_nonce(wallet), expected_start_nonce + num_threads * num_tx_per_thread) + self.assertEqual(self.nonce_manager.get_pending_nonces(wallet), SortedSet()) + self.assertEqual(self.nonce_manager.get_total_pending_transactions(wallet), 0) + + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_payment_protocol.py b/tests/test_payment_protocol.py new file mode 100644 index 0000000..c904e01 --- /dev/null +++ b/tests/test_payment_protocol.py @@ -0,0 +1,318 @@ +""" +Unit tests for Agent-to-Agent Payment Protocol + +Run with: pytest tests/test_payment_protocol.py -v + +Uses mock chain state to test: +- Payment creation with funds locked +- Confirmation flow +- Timeout/refund flow +- Cancellation +- Nonce management +""" + +import pytest +import time +from unittest.mock import MagicMock, patch, PropertyMock +from decimal import Decimal + +# ─── Mock Web3 ───────────────────────────────────────────────────────────── + +class MockWeb3: + def __init__(self): + self.eth = MockEth() + self.to_checksum_address = lambda a: a.lower() + +class MockEth: + def __init__(self): + self.gas_price = 20000000000 # 20 gwei + self.block_number = 1000 + self.get_transaction_count = lambda addr: 1 + self.get_balance = lambda addr: 10**21 # 1000 ETH + self.wait_for_transaction_receipt = MagicMock(return_value=MagicMock(status=1)) + self.send_raw_transaction = MagicMock(return_value=b'\x00' * 32) + +class MockAccount: + def __init__(self, addr="0x742d35Cc6634C0532925a3b844Bc9e7595f"): + self.address = addr + + @staticmethod + def from_key(key): + return MockAccount() + + def sign_transaction(self, tx): + return MagicMock(raw_transaction=b'\x00' * 32) + +# ─── Mock Contract ───────────────────────────────────────────────────────── + +class MockContract: + def __init__(self): + self.address = "0x1234567890123456789012345678901234567890" + self.payments = {} # requestId → payment data + + def functions(self): + return MockContractFunctions(self) + + def/events(self): + return MockEvents() + + +class MockContractFunctions: + def __init__(self, contract): + self.contract = contract + + def createPayment(self, requestId, payee, timeoutBlocks, challengePeriod): + return MockFn("createPayment", self.contract, requestId, payee, timeoutBlocks, challengePeriod) + + def confirmPayment(self, requestId): + return MockFn("confirmPayment", self.contract, requestId) + + def requestRefund(self, requestId): + return MockFn("requestRefund", self.contract, requestId) + + def cancelPayment(self, requestId): + return MockFn("cancelPayment", self.contract, requestId) + + def getPayment(self, requestId): + return MockFn("getPayment", self.contract, requestId) + + def isExpired(self, requestId): + return MockFn("isExpired", self.contract, requestId) + + +class MockFn: + def __init__(self, name, contract, *args): + self.name = name + self.contract = contract + self.args = args + + def build_transaction(self, tx_params): + return { + 'to': self.contract.address, + 'data': '0x', + **tx_params + } + + def call(self): + if self.name == "createPayment": + req_id = self.args[0] + self.contract.payments[req_id] = { + 'state': 1, # Locked + 'amount': 1000000000000000000, # 1 ETH + 'createdAt': 1000 + } + return True + elif self.name == "getPayment": + req_id = self.args[0] + if req_id in self.contract.payments: + p = self.contract.payments[req_id] + return [ + "0x742d35Cc6634C0532925a3b844Bc9e7595f", # payer + "0x853d955aCEf822Db058eb8505911ED77F175b99e", # payee + p['amount'], + 100, # timeout_blocks + 10, # challenge_period + p['state'], + req_id, + p['createdAt'] + ] + return ["", "", 0, 0, 0, 0, "", 0] + elif self.name == "isExpired": + req_id = self.args[0] + if req_id in self.contract.payments: + return self.contract.payments[req_id]['state'] == 1 # Locked + return False + return None + + +# ─── Import and test the module ──────────────────────────────────────────── + +def test_payment_request_creation(): + """Test that PaymentRequest dataclass works correctly""" + from src.payment_protocol import PaymentRequest + + req = PaymentRequest( + request_id="test-123", + payer="0x742d35Cc6634C0532925a3b844Bc9e7595f", + payee="0x853d955aCEf822Db058eb8505911ED77F175b99e", + amount_wei=10**18, + timeout_blocks=100, + challenge_period_blocks=10, + description="Test payment", + currency="ETH" + ) + + assert req.request_id == "test-123" + assert req.amount_wei == 10**18 + assert req.timeout_blocks == 100 + assert req.status == "pending" + assert req.currency == "ETH" + + # Test JSON serialization + json_str = req.to_json() + assert "test-123" in json_str + assert "locked" not in json_str # status should be "pending" + + # Test content hash + h = req.content_hash() + assert h.startswith("0x") + assert len(h) == 66 # 0x + 64 hex chars + + +def test_payment_request_from_dict(): + """Test deserialization from dict""" + from src.payment_protocol import PaymentRequest + + d = { + "version": "1.0", + "request_id": "test-456", + "payer": "0x742d35Cc6634C0532925a3b844Bc9e7595f", + "payee": "0x853d955aCEf822Db058eb8505911ED77F175b99e", + "amount_wei": 500000000000000000, + "amount_usd": "50.00", + "currency": "ETH", + "chain_id": 1, + "timeout_blocks": 100, + "challenge_period_blocks": 10, + "description": "Test", + "metadata": {"order_id": "123"}, + "created_at": 1234567890.0, + "status": "locked" + } + + req = PaymentRequest.from_dict(d) + assert req.request_id == "test-456" + assert req.amount_usd == Decimal("50.00") + assert req.status == "locked" + + +def test_format_wei(): + """Test wei formatting""" + from src.payment_protocol import format_wei + + assert "1.000000 ETH" in format_wei(10**18) + assert "0.500000 ETH" in format_wei(5 * 10**17) + assert "0.000001 ETH" in format_wei(10**12) + + +def test_parse_wei(): + """Test parsing human-readable amounts to wei""" + from src.payment_protocol import parse_wei + + assert parse_wei("1 ETH") == 10**18 + assert parse_wei("0.5 ETH") == 5 * 10**17 + assert parse_wei("1000000000000000000 wei") == 10**18 + assert parse_wei("1.5 ETH") == int(Decimal("1.5") * Decimal(10**18)) + + +def test_payment_state_enum(): + """Test PaymentState enum""" + from src.payment_protocol import PaymentState + + assert PaymentState.LOCKED.value == "locked" + assert PaymentState.RELEASED.value == "released" + assert PaymentState.REFUNDED.value == "refunded" + + +def test_content_hash_deterministic(): + """Test that content_hash is deterministic""" + from src.payment_protocol import PaymentRequest + + req1 = PaymentRequest( + request_id="det-test", + payer="0x742d35Cc6634C0532925a3b844Bc9e7595f", + payee="0x853d955aCEf822Db058eb8505911ED77F175b99e", + amount_wei=10**18, + currency="ETH" + ) + + req2 = PaymentRequest( + request_id="det-test", + payer="0x742d35Cc6634C0532925a3b844Bc9e7595f", + payee="0x853d955aCEf822Db058eb8505911ED77F175b99e", + amount_wei=10**18, + currency="ETH" + ) + + assert req1.content_hash() == req2.content_hash() + + # Different content → different hash + req3 = PaymentRequest( + request_id="det-test-CHANGED", + payer="0x742d35Cc6634C0532925a3b844Bc9e7595f", + payee="0x853d955aCEf822Db058eb8505911ED77F175b99e", + amount_wei=10**18, + currency="ETH" + ) + assert req1.content_hash() != req3.content_hash() + + +def test_mock_contract_create(): + """Test mock contract payment creation flow""" + contract = MockContract() + fns = contract.functions() + + result = fns.createPayment("req-001", "0xPayee", 100, 10).call() + assert result == True + assert "req-001" in contract.payments + assert contract.payments["req-001"]["state"] == 1 # Locked + + +def test_payment_lifecycle(): + """Test full payment lifecycle: create → confirm → released""" + contract = MockContract() + + # Create + contract.functions().createPayment("req-002", "0xPayee", 100, 10).call() + payment = contract.functions().getPayment("req-002").call() + assert payment[5] == 1 # Locked state + + # Confirm (would release funds) + contract.payments["req-002"]["state"] = 3 # Released + payment_after = contract.functions().getPayment("req-002").call() + assert payment_after[5] == 3 # Released state + + +def test_timeout_and_refund(): + """Test timeout → refund flow""" + contract = MockContract() + + # Create payment with very short timeout (would be expired in mock) + contract.functions().createPayment("req-003", "0xPayee", 1, 10).call() + + # Simulate expired payment + contract.payments["req-003"]["state"] = 1 # Still locked but should be expired + is_expired = contract.functions().isExpired("req-003").call() + assert is_expired == True + + # After challenge period → refund + contract.payments["req-003"]["state"] = 4 # Refunded + payment = contract.functions().getPayment("req-003").call() + assert payment[5] == 4 # Refunded state + + +def test_payment_metadata(): + """Test that arbitrary metadata can be stored with payment""" + from src.payment_protocol import PaymentRequest + + req = PaymentRequest( + request_id="meta-test", + payer="0x742d35Cc6634C0532925a3b844Bc9e7595f", + payee="0x853d955aCEf822Db058eb8505911ED77F175b99e", + amount_wei=10**18, + metadata={ + "order_id": "ORD-12345", + "service": "code-review", + "tags": ["solidity", "audit"], + "priority": "high" + } + ) + + d = req.to_dict() + assert d["metadata"]["order_id"] == "ORD-12345" + assert d["metadata"]["service"] == "code-review" + assert "solidity" in d["metadata"]["tags"] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_x402_middleware.py b/tests/test_x402_middleware.py new file mode 100644 index 0000000..c543b50 --- /dev/null +++ b/tests/test_x402_middleware.py @@ -0,0 +1,207 @@ +"""Tests for x402 server-side middleware.""" + +import asyncio +import json +import pytest +import time +from decimal import Decimal +from unittest.mock import MagicMock, AsyncMock, patch + +from switchboard.x402_middleware import ( + X402Middleware, + PaymentOffer, + PaymentProof, + PaymentScheme, + PaymentRecord, +) + + +# ─── PaymentOffer Tests ────────────────────────────────────────────────────── + +class TestPaymentOffer: + def test_from_header_minimal(self): + header = json.dumps({"amount": "1000000", "recipient": "0xABC", "chainId": 8453}) + offer = PaymentOffer.from_header(header, endpoint="/api/infer") + assert offer.amount_wei == 1000000 + assert offer.recipient == "0xABC" + assert offer.chain_id == 8453 + assert offer.scheme == PaymentScheme.EXACT + assert offer.endpoint == "/api/infer" + + def test_from_header_full(self): + header = json.dumps({ + "amount": "5000000000000000", + "recipient": "0xDEF", + "chainId": 1, + "currency": "ETH", + "scheme": "escrow", + "description": "inference job", + "nonce": "abc123", + "expiresAt": int(time.time()) + 3600, + }) + offer = PaymentOffer.from_header(header) + assert offer.scheme == PaymentScheme.ESCROW + assert offer.nonce == "abc123" + assert not offer.is_expired() + + def test_expired_offer(self): + offer = PaymentOffer( + amount_wei=100, + currency="ETH", + recipient="0x1", + chain_id=1, + expires_at=int(time.time()) - 10, + ) + assert offer.is_expired() + + def test_no_expiry_not_expired(self): + offer = PaymentOffer( + amount_wei=100, currency="ETH", recipient="0x1", chain_id=1 + ) + assert not offer.is_expired() + + +# ─── PaymentProof Tests ────────────────────────────────────────────────────── + +class TestPaymentProof: + def test_to_header_roundtrip(self): + proof = PaymentProof( + tx_hash="0xabc", + chain_id=8453, + payer="0x123", + amount_wei=1000000, + nonce="n1", + ) + header = proof.to_header() + data = json.loads(header) + assert data["txHash"] == "0xabc" + assert data["chainId"] == 8453 + assert data["payer"] == "0x123" + assert data["amount"] == 1000000 + + +# ─── Middleware Validation Tests ───────────────────────────────────────────── + +class TestMiddlewareValidation: + def _make_middleware(self, **kwargs): + client = MagicMock() + client.wallet_address = "0xPAYER" + return X402Middleware(payment_client=client, **kwargs) + + def test_rejects_expired_offer(self): + mw = self._make_middleware() + offer = PaymentOffer( + amount_wei=100, currency="ETH", recipient="0x1", + chain_id=1, expires_at=int(time.time()) - 1, + ) + with pytest.raises(ValueError, match="expired"): + mw._validate_offer(offer) + + def test_rejects_over_cap(self): + mw = self._make_middleware(max_payment_wei=1000) + offer = PaymentOffer( + amount_wei=2000, currency="ETH", recipient="0x1", chain_id=1, + ) + with pytest.raises(ValueError, match="exceeds cap"): + mw._validate_offer(offer) + + def test_rejects_unknown_recipient(self): + mw = self._make_middleware(allowed_recipients={"0xGOOD"}) + offer = PaymentOffer( + amount_wei=100, currency="ETH", recipient="0xBAD", chain_id=1, + ) + with pytest.raises(ValueError, match="not in allowlist"): + mw._validate_offer(offer) + + def test_accepts_valid_offer(self): + mw = self._make_middleware( + max_payment_wei=10**18, + allowed_recipients={"0xGOOD"}, + ) + offer = PaymentOffer( + amount_wei=10**15, currency="ETH", recipient="0xGOOD", chain_id=1, + ) + mw._validate_offer(offer) # Should not raise + + def test_rejects_over_gas_budget(self): + tracker = MagicMock() + tracker.can_send_transaction.return_value = False + mw = self._make_middleware(gas_tracker=tracker) + offer = PaymentOffer( + amount_wei=100, currency="ETH", recipient="0x1", chain_id=1, + ) + with pytest.raises(ValueError, match="gas budget"): + mw._validate_offer(offer) + + +# ─── Payment Execution Tests ──────────────────────────────────────────────── + +class TestPaymentExecution: + def test_exact_payment(self): + client = MagicMock() + client.wallet_address = "0xPAYER" + client.sign_and_send.return_value = "0xTXHASH" + client.wait_for_confirmations.return_value = {"status": 1} + + mw = X402Middleware(payment_client=client) + offer = PaymentOffer( + amount_wei=10**15, currency="ETH", + recipient="0xRECIPIENT", chain_id=8453, nonce="n1", + ) + + proof = mw._pay_onchain(offer) + assert proof.tx_hash == "0xTXHASH" + assert proof.payer == "0xPAYER" + assert proof.chain_id == 8453 + + client.sign_and_send.assert_called_once() + tx_arg = client.sign_and_send.call_args[0][0] + assert tx_arg["to"] == "0xRECIPIENT" + assert tx_arg["value"] == 10**15 + + def test_escrow_payment(self): + client = MagicMock() + client.wallet_address = "0xPAYER" + mock_req = MagicMock() + mock_req.request_id = "req-123" + client.create_payment.return_value = mock_req + + mw = X402Middleware(payment_client=client) + offer = PaymentOffer( + amount_wei=10**15, currency="ETH", + recipient="0xRECIPIENT", chain_id=1, + scheme=PaymentScheme.ESCROW, + ) + + proof = mw._pay_onchain(offer) + assert proof.tx_hash == "req-123" + client.create_payment.assert_called_once() + + +# ─── Spend Summary Tests ──────────────────────────────────────────────────── + +class TestSpendSummary: + def test_empty_summary(self): + client = MagicMock() + client.wallet_address = "0x1" + mw = X402Middleware(payment_client=client) + summary = mw.get_spend_summary() + assert summary["total_payments"] == 0 + assert summary["total_spent_wei"] == 0 + + def test_tracks_payments(self): + client = MagicMock() + client.wallet_address = "0x1" + mw = X402Middleware(payment_client=client) + + offer = PaymentOffer(amount_wei=1000, currency="ETH", recipient="0x2", chain_id=1) + proof = PaymentProof(tx_hash="0xa", chain_id=1, payer="0x1", amount_wei=1000) + mw.payment_history.append(PaymentRecord( + endpoint="/api/a", offer=offer, proof=proof, response_status=200, + )) + mw.total_spent_wei = 1000 + + summary = mw.get_spend_summary() + assert summary["total_payments"] == 1 + assert summary["total_spent_wei"] == 1000 + assert summary["by_endpoint"]["/api/a"] == 1000 diff --git a/tests/test_zap_transport.py b/tests/test_zap_transport.py new file mode 100644 index 0000000..716917f --- /dev/null +++ b/tests/test_zap_transport.py @@ -0,0 +1,186 @@ +"""Roundtrip tests for switchboard.zap_transport. + +Skipped if zap_py is not installed — install with:: + + pip install 'luxfi-zap @ git+https://github.com/luxfi/zap@main#subdirectory=python' +""" + +from __future__ import annotations + +import time + +import pytest + +zap_py = pytest.importorskip("zap_py") + +from switchboard.x402_middleware import PaymentOffer, PaymentProof, PaymentScheme +from switchboard import zap_transport as zt + + +VITALIK = "0xd8da6bf26964af9d7eed9e03e53415d37aa96045" +USDC = "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48" +SAMPLE_TX = "0x" + "ab" * 32 + + +def test_offer_schema_layout(): + """Schema is well-formed and has the expected fields.""" + fields = {f.name for f in zt.OFFER_SCHEMA.fields} + assert fields == { + "scheme", + "chain_id", + "expires_at", + "recipient", + "amount", + "currency", + "description", + "endpoint", + "nonce", + } + # Final align(8); nine fields, smallest start. + assert zt.OFFER_SCHEMA.size > 0 + assert zt.OFFER_SCHEMA.size % 8 == 0 + + +def test_offer_roundtrip_minimal(): + offer = PaymentOffer( + amount_wei=1_000_000, + currency="USDC", + recipient=USDC, + chain_id=8453, + ) + wire = zt.encode_offer(offer) + assert isinstance(wire, bytes) + out = zt.decode_offer(wire) + + assert out.amount_wei == offer.amount_wei + assert out.currency == offer.currency + assert out.recipient.lower() == offer.recipient.lower() + assert out.chain_id == offer.chain_id + assert out.scheme == PaymentScheme.EXACT # default + assert out.description == "" + assert out.endpoint == "" + assert out.nonce == "" + assert out.expires_at is None # 0 sentinel restored to None + + +def test_offer_roundtrip_full(): + offer = PaymentOffer( + amount_wei=12_345_678_901_234_567_890_123, # > uint64 + currency="ETH", + recipient=VITALIK, + chain_id=1, + scheme=PaymentScheme.ESCROW, + description="inference call to /v1/embed", + endpoint="/v1/embed", + nonce="abc-xyz-123", + expires_at=int(time.time()) + 60, + ) + wire = zt.encode_offer(offer) + out = zt.decode_offer(wire) + + assert out.amount_wei == offer.amount_wei + assert out.currency == offer.currency + assert out.recipient.lower() == offer.recipient.lower() + assert out.chain_id == offer.chain_id + assert out.scheme == offer.scheme + assert out.description == offer.description + assert out.endpoint == offer.endpoint + assert out.nonce == offer.nonce + assert out.expires_at == offer.expires_at + + +def test_offer_each_scheme_roundtrips(): + for scheme in PaymentScheme: + offer = PaymentOffer(amount_wei=1, currency="ETH", recipient=USDC, chain_id=1, scheme=scheme) + wire = zt.encode_offer(offer) + out = zt.decode_offer(wire) + assert out.scheme == scheme + + +def test_offer_amount_uint256_max(): + """Wire format must accept a full uint256 amount, not truncate to uint64.""" + big = (1 << 256) - 1 + offer = PaymentOffer(amount_wei=big, currency="ETH", recipient=USDC, chain_id=1) + wire = zt.encode_offer(offer) + out = zt.decode_offer(wire) + assert out.amount_wei == big + + +def test_offer_amount_overflow_rejected(): + offer = PaymentOffer(amount_wei=1 << 256, currency="ETH", recipient=USDC, chain_id=1) + with pytest.raises(ValueError): + zt.encode_offer(offer) + + +def test_offer_negative_amount_rejected(): + offer = PaymentOffer(amount_wei=-1, currency="ETH", recipient=USDC, chain_id=1) + with pytest.raises(ValueError): + zt.encode_offer(offer) + + +def test_proof_roundtrip(): + proof = PaymentProof( + tx_hash=SAMPLE_TX, + chain_id=8453, + payer=VITALIK, + amount_wei=999_999, + nonce="proof-1", + timestamp=1714421000.0, + ) + wire = zt.encode_proof(proof) + out = zt.decode_proof(wire) + + assert out.tx_hash == SAMPLE_TX + assert out.chain_id == proof.chain_id + assert out.payer.lower() == proof.payer.lower() + assert out.amount_wei == proof.amount_wei + assert out.nonce == proof.nonce + assert out.timestamp == 1714421000.0 + + +def test_proof_invalid_tx_hash_length(): + proof = PaymentProof( + tx_hash="0xdead", # too short + chain_id=1, + payer=VITALIK, + amount_wei=1, + ) + with pytest.raises(ValueError): + zt.encode_proof(proof) + + +def test_zap_not_available_raises_when_disabled(monkeypatch): + """If zap_py is force-disabled, encode/decode must raise ZapNotAvailable.""" + monkeypatch.setattr(zt, "HAS_ZAP_PY", False) + offer = PaymentOffer(amount_wei=1, currency="ETH", recipient=USDC, chain_id=1) + with pytest.raises(zt.ZapNotAvailable): + zt.encode_offer(offer) + with pytest.raises(zt.ZapNotAvailable): + zt.decode_offer(b"junk") + + +def test_offer_wire_smaller_than_json(): + """Sanity: ZAP encoding is genuinely tighter than the JSON header path.""" + import json + + offer = PaymentOffer( + amount_wei=1_000_000, + currency="USDC", + recipient=USDC, + chain_id=8453, + nonce="n-1", + endpoint="/v1/x", + ) + wire = zt.encode_offer(offer) + json_blob = json.dumps({ + "amount": str(offer.amount_wei), + "currency": offer.currency, + "recipient": offer.recipient, + "chainId": offer.chain_id, + "scheme": offer.scheme.value, + "endpoint": offer.endpoint, + "nonce": offer.nonce, + }).encode() + # Not an absolute guarantee (small offers may bloat with the ZAP header), + # but for any realistic offer the binary form should win. + assert len(wire) <= len(json_blob) + 64 # tolerance for fixed ZAP header diff --git a/web/README.md b/web/README.md new file mode 100644 index 0000000..a42e385 --- /dev/null +++ b/web/README.md @@ -0,0 +1,38 @@ +# switchboard web dashboard + +Zero-build interactive map of the 2026 agent-payment rails — x402, MPP, +AP2, Circle Nanopayments — and how switchboard's on-chain escrow fits +alongside them. + +## What it shows + +- **Protocol flow** — a sequence diagram per protocol with 4 numbered + steps. Click a step to see the wire-level snippet (HTTP, JSON body, + Solidity call). +- **Compatibility matrix** — side-by-side: transport, settlement asset, + agent↔agent vs agent↔server, streaming / sessions, disputes, fiat + rails, license. +- **How switchboard fits** — the gap these rails leave (agent-side keys, + nonces, budgets, escrow) and what this repo provides. + +## Run locally + +```bash +python3 -m http.server -d web 8080 +``` + +## Hosted + +- kcolbchain.com/switchboard/ + +## Source references + +Protocol summaries reflect April 2026 state: + +- x402 joined the Linux Foundation 2026-04-02 (Coinbase, Google, AWS, + Microsoft, Stripe, Visa, Mastercard as founding members). +- MPP (Stripe × Paradigm × Tempo) went live on Tempo L1 mainnet 2026-03-18. +- Circle Nanopayments testnet opened 2026-03-03. +- Google Cloud announced AP2 the same week. +- Mastercard's Verifiable Intent specification (open-sourced on GitHub) + complements AP2. diff --git a/web/index.html b/web/index.html new file mode 100644 index 0000000..8261c84 --- /dev/null +++ b/web/index.html @@ -0,0 +1,420 @@ + + + + + +switchboard — agent payments infrastructure + + + + + +
+
+

switchboard. agent payments infrastructure

+

Agent wallets, gas budgets, agent-to-agent escrow. Compatible with the open agent-payment rails of 2026: x402, MPP, AP2, Circle Nanopayments.

+
+
+ source + docs +
+
+ +
+ +

Protocol flow

+
+ + + + + +
+ +
+
+
+

+
+
+
+
+

Steps

+
    +
    +
    + +
    +
    + + + + + + + +
    +
    +

    Reference code — current step

    +
    Click a step on the left to see the wire-level snippet.
    +
    +
    +
    + +

    Compatibility matrix

    +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    x402MPPAP2Circle Nanoswitchboard escrow
    TransportHTTP header on 402HTTP + Tempo txsA2A over gRPC/JSON-RPCHTTP/SDK, off-chain+onchainHTTP + chain txs (escrow)
    Settlement assetUSDC (multichain)USDC on Tempo, fiat via SPTprotocol-agnosticUSDC nanopaymentsnative ETH/token + USDC
    Agent ↔ Agentindirect (server intermediary)nativenativevia Circle Walletsnative
    Agent ↔ API / MCP serverprimary use casesupportedvia payments facilitatornativevia wrapper
    Streaming / sessionsper-request (v2 adds multi-step)streamed micro-txs per sessionyes (verifiable intent mandate)nanopayments < $0.000001timeout + challenge-period escrow
    Dispute / refundnone (idempotent retry)SPT + card dispute railsAP2 verifiable intent → issuerCircle policy + ruleson-chain timeout + refund
    Fiat railscrypto onlyvia Stripe SPTcard networks (Visa, MC)via Circle on/off rampcrypto only
    LicenseApache-2.0Apache-2.0Apache-2.0proprietary SDK + open skillsMIT
    +
    + +

    How switchboard fits

    +
    +

    The open rails above settle agent payments; they don't manage the agent side — keys, nonces, gas budgets, counterparty escrow. switchboard fills that gap:

    +
      +
    • Agent wallets — MPC key management for autonomous actors (tracked in #1).
    • +
    • Client-side nonce manager with reorg protection — so a bursty agent doesn't brick its own queue.
    • +
    • Gas budget tracker — rolling hour/day limits, auto-pause on exhaustion (#14, merged).
    • +
    • Agent-to-agent escrowAgentEscrow.sol with timeout + challenge-period refund (#2, merged).
    • +
    • x402 / MPP / AP2 server middleware — so switchboard-managed agents can respond to 402 Payment Required and initiate MPP sessions out of the box (tokenomics / integration issues incoming).
    • +
    +
    + +
    + + + + + +