From 3b4d2c0ccd5a39d16c66363c1897af03d3a7f521 Mon Sep 17 00:00:00 2001
From: Jonathan Segev
Date: Fri, 16 May 2025 12:07:52 -0400
Subject: [PATCH] models - anthropic

Co-authored-by: Patrick Gray
Co-authored-by: Jason Kim
---
 pyproject.toml                         |   7 +-
 src/strands/models/anthropic.py        | 357 +++++++++++++
 tests-integ/test_model_anthropic.py    |  48 ++
 tests/strands/models/test_anthropic.py | 669 +++++++++++++++++++++++++
 4 files changed, 1079 insertions(+), 2 deletions(-)
 create mode 100644 src/strands/models/anthropic.py
 create mode 100644 tests-integ/test_model_anthropic.py
 create mode 100644 tests/strands/models/test_anthropic.py

diff --git a/pyproject.toml b/pyproject.toml
index 5e43b418a..0d920f1d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,9 @@ Documentation = "https://strandsagents.com"
 packages = ["src/strands"]
 
 [project.optional-dependencies]
+anthropic = [
+    "anthropic>=0.21.0,<1.0.0",
+]
 dev = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",
@@ -71,7 +74,7 @@ ollama = [
 ]
 
 [tool.hatch.envs.hatch-static-analysis]
-features = ["litellm", "ollama"]
+features = ["anthropic", "litellm", "ollama"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
@@ -94,7 +97,7 @@ lint-fix = [
 ]
 
 [tool.hatch.envs.hatch-test]
-features = ["litellm", "ollama"]
+features = ["anthropic", "litellm", "ollama"]
 extra-dependencies = [
     "moto>=5.1.0,<6.0.0",
     "pytest>=8.0.0,<9.0.0",
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
new file mode 100644
index 000000000..704114eb9
--- /dev/null
+++ b/src/strands/models/anthropic.py
@@ -0,0 +1,357 @@
+"""Anthropic Claude model provider.
+
+- Docs: https://docs.anthropic.com/claude/reference/getting-started-with-the-api
+"""
+
+import base64
+import json
+import logging
+import mimetypes
+from typing import Any, Iterable, Optional, TypedDict, cast
+
+import anthropic
+from typing_extensions import Required, Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
+from ..types.models import Model
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolSpec
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicModel(Model):
+    """Anthropic model provider implementation."""
+
+    EVENT_TYPES = {
+        "message_start",
+        "content_block_start",
+        "content_block_delta",
+        "content_block_stop",
+        "message_stop",
+    }
+
+    OVERFLOW_MESSAGES = {
+        "input is too long",
+        "input length exceeds context window",
+        "input and output tokens exceed your context limit",
+    }
+
+    class AnthropicConfig(TypedDict, total=False):
+        """Configuration options for Anthropic models.
+
+        Attributes:
+            max_tokens: Maximum number of tokens to generate.
+            model_id: Claude model ID (e.g., "claude-3-7-sonnet-latest").
+                For a complete list of supported models, see
+                https://docs.anthropic.com/en/docs/about-claude/models/all-models.
+            params: Additional model parameters (e.g., temperature).
+                For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
+        """
+
+        max_tokens: Required[int]
+        model_id: Required[str]
+        params: Optional[dict[str, Any]]
+
+    def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[AnthropicConfig]):
+        """Initialize provider instance.
+
+        Args:
+            client_args: Arguments for the underlying Anthropic client (e.g., api_key).
+                For a complete list of supported arguments, see https://docs.anthropic.com/en/api/client-sdks.
+            **model_config: Configuration options for the Anthropic model.
+        """
+        self.config = AnthropicModel.AnthropicConfig(**model_config)
+
+        logger.debug("config=<%s> | initializing", self.config)
+
+        client_args = client_args or {}
+        self.client = anthropic.Anthropic(**client_args)
+
+    @override
+    def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None:  # type: ignore[override]
+        """Update the Anthropic model configuration with the provided arguments.
+
+        Args:
+            **model_config: Configuration overrides.
+        """
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> AnthropicConfig:
+        """Get the Anthropic model configuration.
+
+        Returns:
+            The Anthropic model configuration.
+        """
+        return self.config
+
+    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
+        """Format an Anthropic content block.
+
+        Args:
+            content: Message content.
+
+        Returns:
+            Anthropic formatted content block.
+        """
+        if "document" in content:
+            return {
+                "source": {
+                    "data": base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8"),
+                    "media_type": mimetypes.types_map.get(
+                        f".{content['document']['format']}", "application/octet-stream"
+                    ),
+                    "type": "base64",
+                },
+                "title": content["document"]["name"],
+                "type": "document",
+            }
+
+        if "image" in content:
+            return {
+                "source": {
+                    "data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"),
+                    "media_type": mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream"),
+                    "type": "base64",
+                },
+                "type": "image",
+            }
+
+        if "reasoningContent" in content:
+            return {
+                "signature": content["reasoningContent"]["reasoningText"]["signature"],
+                "thinking": content["reasoningContent"]["reasoningText"]["text"],
+                "type": "thinking",
+            }
+
+        if "text" in content:
+            return {"text": content["text"], "type": "text"}
+
+        if "toolUse" in content:
+            return {
+                "id": content["toolUse"]["toolUseId"],
+                "input": content["toolUse"]["input"],
+                "name": content["toolUse"]["name"],
+                "type": "tool_use",
+            }
+
+        if "toolResult" in content:
+            return {
+                "content": [
+                    self._format_request_message_content(cast(ContentBlock, tool_result_content))
+                    for tool_result_content in content["toolResult"]["content"]
+                ],
+                "is_error": content["toolResult"]["status"] == "error",
+                "tool_use_id": content["toolResult"]["toolUseId"],
+                "type": "tool_result",
+            }
+
+        return {"text": json.dumps(content), "type": "text"}
+
+    def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
+        """Format an Anthropic messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+
+        Returns:
+            An Anthropic messages array.
+        """
+        formatted_messages = []
+
+        for message in messages:
+            formatted_contents: list[dict[str, Any]] = []
+
+            for content in message["content"]:
+                if "cachePoint" in content:
+                    formatted_contents[-1]["cache_control"] = {"type": "ephemeral"}
+                    continue
+
+                formatted_contents.append(self._format_request_message_content(content))
+
+            if formatted_contents:
+                formatted_messages.append({"content": formatted_contents, "role": message["role"]})
+
+        return formatted_messages
+
+    @override
+    def format_request(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> dict[str, Any]:
+        """Format an Anthropic streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            An Anthropic streaming request.
+        """
+        return {
+            "max_tokens": self.config["max_tokens"],
+            "messages": self._format_request_messages(messages),
+            "model": self.config["model_id"],
+            "tools": [
+                {
+                    "name": tool_spec["name"],
+                    "description": tool_spec["description"],
+                    "input_schema": tool_spec["inputSchema"]["json"],
+                }
+                for tool_spec in tool_specs or []
+            ],
+            **({"system": system_prompt} if system_prompt else {}),
+            **(self.config.get("params") or {}),
+        }
+
+    @override
+    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+        """Format the Anthropic response events into standardized message chunks.
+
+        Args:
+            event: A response event from the Anthropic model.
+
+        Returns:
+            The formatted chunk.
+
+        Raises:
+            RuntimeError: If chunk_type is not recognized.
+                This error should never be encountered as we control chunk_type in the stream method.
+        """
+        match event["type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_block_start":
+                content = event["content_block"]
+
+                if content["type"] == "tool_use":
+                    return {
+                        "contentBlockStart": {
+                            "contentBlockIndex": event["index"],
+                            "start": {
+                                "toolUse": {
+                                    "name": content["name"],
+                                    "toolUseId": content["id"],
+                                }
+                            },
+                        }
+                    }
+
+                return {"contentBlockStart": {"contentBlockIndex": event["index"], "start": {}}}
+
+            case "content_block_delta":
+                delta = event["delta"]
+
+                match delta["type"]:
+                    case "signature_delta":
+                        return {
+                            "contentBlockDelta": {
+                                "contentBlockIndex": event["index"],
+                                "delta": {
+                                    "reasoningContent": {
+                                        "signature": delta["signature"],
+                                    },
+                                },
+                            },
+                        }
+
+                    case "thinking_delta":
+                        return {
+                            "contentBlockDelta": {
+                                "contentBlockIndex": event["index"],
+                                "delta": {
+                                    "reasoningContent": {
+                                        "text": delta["thinking"],
+                                    },
+                                },
+                            },
+                        }
+
+                    case "input_json_delta":
+                        return {
+                            "contentBlockDelta": {
+                                "contentBlockIndex": event["index"],
+                                "delta": {
+                                    "toolUse": {
+                                        "input": delta["partial_json"],
+                                    },
+                                },
+                            },
+                        }
+
+                    case "text_delta":
+                        return {
+                            "contentBlockDelta": {
+                                "contentBlockIndex": event["index"],
+                                "delta": {
+                                    "text": delta["text"],
+                                },
+                            },
+                        }
+
+                    case _:
+                        raise RuntimeError(
+                            f"event_type=<{event['type']}>, delta_type=<{delta['type']}> | unknown type"
+                        )
+
+            case "content_block_stop":
+                return {"contentBlockStop": {"contentBlockIndex": event["index"]}}
+
+            case "message_stop":
+                message = event["message"]
+
+                return {"messageStop": {"stopReason": message["stop_reason"]}}
+
+            case "metadata":
+                usage = event["usage"]
+
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": usage["input_tokens"],
+                            "outputTokens": usage["output_tokens"],
+                            "totalTokens": usage["input_tokens"] + usage["output_tokens"],
+                        },
+                        "metrics": {
+                            "latencyMs": 0,  # TODO
+                        },
+                    }
+                }
+
+            case _:
+                raise RuntimeError(f"event_type=<{event['type']}> | unknown type")
+
+    @override
+    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
+        """Send the request to the Anthropic model and get the streaming response.
+
+        Args:
+            request: The formatted request to send to the Anthropic model.
+
+        Returns:
+            An iterable of response events from the Anthropic model.
+
+        Raises:
+            ContextWindowOverflowException: If the input exceeds the model's context window.
+            ModelThrottledException: If the request is throttled by Anthropic.
+ """ + try: + with self.client.messages.stream(**request) as stream: + for event in stream: + if event.type in AnthropicModel.EVENT_TYPES: + yield event.dict() + + usage = event.message.usage # type: ignore + yield {"type": "metadata", "usage": usage.dict()} + + except anthropic.RateLimitError as error: + raise ModelThrottledException(str(error)) from error + + except anthropic.BadRequestError as error: + if any(overflow_message in str(error).lower() for overflow_message in AnthropicModel.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + + raise error diff --git a/tests-integ/test_model_anthropic.py b/tests-integ/test_model_anthropic.py new file mode 100644 index 000000000..d1f4c7ce3 --- /dev/null +++ b/tests-integ/test_model_anthropic.py @@ -0,0 +1,48 @@ +import os + +import strands +import pytest +from strands import Agent +from strands.models.anthropic import AnthropicModel + + +@pytest.fixture +def model(): + return AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-7-sonnet-20250219", + max_tokens=512, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that uses & instead of ." + + +@pytest.fixture +def agent(model, tools, system_prompt): + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing") +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny", "&"]) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py new file mode 100644 index 000000000..48a1da372 --- /dev/null +++ b/tests/strands/models/test_anthropic.py @@ -0,0 +1,669 @@ +import json +import unittest.mock + +import anthropic +import pytest + +import strands +from strands.models.anthropic import AnthropicModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def anthropic_client(): + with unittest.mock.patch.object(strands.models.anthropic.anthropic, "Anthropic") as mock_client_cls: + yield mock_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "m1" + + +@pytest.fixture +def max_tokens(): + return 1 + + +@pytest.fixture +def model(anthropic_client, model_id, max_tokens): + _ = anthropic_client + + return AnthropicModel(model_id=model_id, max_tokens=max_tokens) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test__init__model_configs(anthropic_client, model_id, max_tokens): + _ = anthropic_client + + model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, params={"temperature": 1}) + + tru_temperature = model.get_config().get("params") + exp_temperature = {"temperature": 1} + + assert tru_temperature == exp_temperature + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_format_request_default(model, messages, model_id, max_tokens): + 
tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_params(model, messages, model_id, max_tokens): + model.update_config(params={"temperature": 1}) + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [], + "temperature": 1, + } + + assert tru_request == exp_request + + +def test_format_request_with_system_prompt(model, messages, model_id, max_tokens, system_prompt): + tru_request = model.format_request(messages, system_prompt=system_prompt) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "system": system_prompt, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_document(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [ + { + "document": { + "format": "pdf", + "name": "test-doc", + "source": {"bytes": b"base64encodeddoc"}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [ + { + "role": "user", + "content": [ + { + "source": { + "data": "YmFzZTY0ZW5jb2RlZGRvYw==", + "media_type": "application/pdf", + "type": "base64", + }, + "title": "test-doc", + "type": "document", + }, + ], + }, + ], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_image(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "jpg", + "source": {"bytes": b"base64encodedimage"}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [ + { + "role": "user", + "content": [ + { + "source": { + "data": "YmFzZTY0ZW5jb2RlZGltYWdl", + "media_type": "image/jpeg", + "type": "base64", + }, + "type": "image", + }, + ], + }, + ], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_reasoning(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "signature": "reasoning_signature", + "text": "reasoning_text", + }, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [ + { + "role": "user", + "content": [ + { + "signature": "reasoning_signature", + "thinking": "reasoning_text", + "type": "thinking", + }, + ], + }, + ], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_tool_use(model, model_id, max_tokens): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [ + { + "role": "assistant", + "content": [ + { + "id": "c1", + "input": {"expression": "2+2"}, + "name": "calculator", + "type": "tool_use", + }, + ], + }, + ], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def 
test_format_request_with_tool_results(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "c1", + "status": "success", + "content": [ + {"text": "see image"}, + { + "image": { + "format": "jpg", + "source": {"bytes": b"base64encodedimage"}, + }, + }, + ], + } + } + ], + } + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [ + { + "role": "user", + "content": [ + { + "content": [ + { + "text": "see image", + "type": "text", + }, + { + "source": { + "data": "YmFzZTY0ZW5jb2RlZGltYWdl", + "media_type": "image/jpeg", + "type": "base64", + }, + "type": "image", + }, + ], + "is_error": False, + "tool_use_id": "c1", + "type": "tool_result", + }, + ], + }, + ], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_other(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [{"other": {"a": 1}}], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [ + { + "role": "user", + "content": [ + { + "text": json.dumps({"other": {"a": 1}}), + "type": "text", + }, + ], + }, + ], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_cache_point(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [ + {"text": "cache me"}, + {"cachePoint": {"type": "default"}}, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [ + { + "role": "user", + "content": [ + { + "cache_control": {"type": "ephemeral"}, + "text": "cache me", + "type": "text", + }, + ], + }, + ], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_request_with_empty_content(model, model_id, max_tokens): + messages = [ + { + "role": "user", + "content": [], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "max_tokens": max_tokens, + "messages": [], + "model": model_id, + "tools": [], + } + + assert tru_request == exp_request + + +def test_format_chunk_message_start(model): + event = {"type": "message_start"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStart": {"role": "assistant"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_start_tool_use(model): + event = { + "content_block": { + "id": "c1", + "name": "calculator", + "type": "tool_use", + }, + "index": 0, + "type": "content_block_start", + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "contentBlockStart": { + "contentBlockIndex": 0, + "start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_start_other(model): + event = { + "content_block": { + "type": "text", + }, + "index": 0, + "type": "content_block_start", + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "contentBlockStart": { + "contentBlockIndex": 0, + "start": {}, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_delta_signature_delta(model): + event = { + "delta": { + "type": "signature_delta", + "signature": "s1", + }, + "index": 0, + "type": "content_block_delta", + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "contentBlockDelta": { + "contentBlockIndex": 0, + "delta": { + "reasoningContent": { + "signature": "s1", + 
}, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_delta_thinking_delta(model): + event = { + "delta": { + "type": "thinking_delta", + "thinking": "t1", + }, + "index": 0, + "type": "content_block_delta", + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "contentBlockDelta": { + "contentBlockIndex": 0, + "delta": { + "reasoningContent": { + "text": "t1", + }, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_delta_input_json_delta_delta(model): + event = { + "delta": { + "type": "input_json_delta", + "partial_json": "{", + }, + "index": 0, + "type": "content_block_delta", + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "contentBlockDelta": { + "contentBlockIndex": 0, + "delta": { + "toolUse": { + "input": "{", + }, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_delta_text_delta(model): + event = { + "delta": { + "type": "text_delta", + "text": "hello", + }, + "index": 0, + "type": "content_block_delta", + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "contentBlockDelta": { + "contentBlockIndex": 0, + "delta": {"text": "hello"}, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_block_delta_unknown(model): + event = { + "delta": { + "type": "unknown", + }, + "type": "content_block_delta", + } + + with pytest.raises(RuntimeError, match="chunk_type=, delta= | unknown type"): + model.format_chunk(event) + + +def test_format_chunk_content_block_stop(model): + event = {"type": "content_block_stop", "index": 0} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStop": {"contentBlockIndex": 0}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop(model): + event = {"type": "message_stop", "message": {"stop_reason": "end_turn"}} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "end_turn"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_metadata(model): + event = { + "type": "metadata", + "usage": {"input_tokens": 1, "output_tokens": 2}, + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "metadata": { + "usage": { + "inputTokens": 1, + "outputTokens": 2, + "totalTokens": 3, + }, + "metrics": { + "latencyMs": 0, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown(model): + event = {"type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model.format_chunk(event) + + +def test_stream(anthropic_client, model): + mock_event_1 = unittest.mock.Mock(type="message_start", dict=lambda: {"type": "message_start"}) + mock_event_2 = unittest.mock.Mock(type="unknown") + mock_event_3 = unittest.mock.Mock( + type="metadata", message=unittest.mock.Mock(usage=unittest.mock.Mock(dict=lambda: {"input_tokens": 1})) + ) + + mock_stream = unittest.mock.MagicMock() + mock_stream.__iter__.return_value = iter([mock_event_1, mock_event_2, mock_event_3]) + anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream + + request = {"model": "m1"} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"type": "message_start"}, + {"type": "metadata", "usage": {"input_tokens": 1}}, + ] + + assert tru_events == exp_events + anthropic_client.messages.stream.assert_called_once_with(**request) + + +def test_stream_rate_limit_error(anthropic_client, model): + anthropic_client.messages.stream.side_effect = 
anthropic.RateLimitError(
+        "rate limit", response=unittest.mock.Mock(), body=None
+    )
+
+    with pytest.raises(ModelThrottledException, match="rate limit"):
+        next(model.stream({}))
+
+
+@pytest.mark.parametrize(
+    "overflow_message",
+    [
+        "...input is too long...",
+        "...input length exceeds context window...",
+        "...input and output tokens exceed your context limit...",
+    ],
+)
+def test_stream_bad_request_overflow_error(overflow_message, anthropic_client, model):
+    anthropic_client.messages.stream.side_effect = anthropic.BadRequestError(
+        overflow_message, response=unittest.mock.Mock(), body=None
+    )
+
+    with pytest.raises(ContextWindowOverflowException):
+        next(model.stream({}))
+
+
+def test_stream_bad_request_error(anthropic_client, model):
+    anthropic_client.messages.stream.side_effect = anthropic.BadRequestError(
+        "bad", response=unittest.mock.Mock(), body=None
+    )
+
+    with pytest.raises(anthropic.BadRequestError, match="bad"):
+        next(model.stream({}))
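
As a quick orientation for reviewers, here is a minimal usage sketch of the provider this patch adds, adapted from tests-integ/test_model_anthropic.py. The model_id, max_tokens, temperature, and the ANTHROPIC_API_KEY environment variable are illustrative choices, not values fixed by the patch.

import os

from strands import Agent
from strands.models.anthropic import AnthropicModel

# Build the provider; client_args are passed straight through to anthropic.Anthropic().
model = AnthropicModel(
    client_args={"api_key": os.getenv("ANTHROPIC_API_KEY")},  # assumes the key is exported in the environment
    model_id="claude-3-7-sonnet-20250219",  # any supported Claude model ID works here
    max_tokens=512,
    params={"temperature": 0.7},  # optional extra Anthropic parameters
)

# Drive it through an Agent, as the integration test does.
agent = Agent(model=model)
result = agent("What is the time and weather in New York?")
print(result.message["content"][0]["text"])

Context overflow and throttling surface as ContextWindowOverflowException and ModelThrottledException respectively, so callers can catch those instead of the raw anthropic exceptions.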