From de49b2a17672dee7d35d4d5c8bf4cf4c9e615a46 Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Wed, 21 Jan 2026 00:39:08 +0000 Subject: [PATCH] feat(media): add S3 location support to image, document, and video content types - Add S3Location TypedDict with uri and bucketOwner fields to media.py - Update ImageSource, DocumentSource, VideoSource to support s3Location - Update Bedrock provider to format s3Location in content blocks - Add S3 source detection and filtering to non-Bedrock providers: Anthropic, OpenAI, Gemini, LiteLLM, Ollama, Mistral, LlamaAPI, LlamaCpp, SageMaker, Writer - Non-Bedrock providers log warning and skip S3 content blocks - Add comprehensive unit tests for new functionality Resolves #10 --- src/strands/models/anthropic.py | 25 ++++++ src/strands/models/bedrock.py | 32 ++++++-- src/strands/models/gemini.py | 48 +++++++++--- src/strands/models/llamaapi.py | 36 +++++++-- src/strands/models/llamacpp.py | 36 +++++++-- src/strands/models/mistral.py | 25 ++++++ src/strands/models/ollama.py | 36 +++++++-- src/strands/models/openai.py | 37 +++++++-- src/strands/models/writer.py | 34 ++++++++- src/strands/types/media.py | 34 ++++++++- tests/strands/models/test_anthropic.py | 90 ++++++++++++++++++++++ tests/strands/models/test_bedrock.py | 102 ++++++++++++++++++++++++- tests/strands/models/test_openai.py | 87 +++++++++++++++++++++ tests/strands/types/test_media.py | 102 +++++++++++++++++++++++++ 14 files changed, 680 insertions(+), 44 deletions(-) create mode 100644 tests/strands/types/test_media.py diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 535c820ee..2764775c9 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -180,6 +180,7 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: An Anthropic messages array. """ formatted_messages = [] + s3_source_warned = False for message in messages: formatted_contents: list[dict[str, Any]] = [] @@ -189,6 +190,13 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: formatted_contents[-1]["cache_control"] = {"type": "ephemeral"} continue + # Check for S3 sources in image, document, or video content + if self._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by Anthropic, skipping content block") + s3_source_warned = True + continue + formatted_contents.append(self._format_request_message_content(content)) if formatted_contents: @@ -196,6 +204,23 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: return formatted_messages + def _has_s3_source(self, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False + def format_request( self, messages: Messages, diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 567a2e147..da5b3d2e9 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -439,9 +439,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "format" in document: result["format"] = document["format"] - # Handle source + # Handle source - supports bytes or s3Location if "source" in document: - result["source"] = {"bytes": document["source"]["bytes"]} + source = document["source"] + if "s3Location" in source: + s3_loc = source["s3Location"] + formatted_s3: dict[str, Any] = {"uri": s3_loc["uri"]} + if "bucketOwner" in s3_loc: + formatted_s3["bucketOwner"] = s3_loc["bucketOwner"] + result["source"] = {"s3Location": formatted_s3} + elif "bytes" in source: + result["source"] = {"bytes": source["bytes"]} # Handle optional fields if "citations" in document and document["citations"] is not None: @@ -462,8 +470,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "image" in content: image = content["image"] source = image["source"] - formatted_source = {} - if "bytes" in source: + formatted_source: dict[str, Any] = {} + if "s3Location" in source: + s3_loc = source["s3Location"] + formatted_s3: dict[str, Any] = {"uri": s3_loc["uri"]} + if "bucketOwner" in s3_loc: + formatted_s3["bucketOwner"] = s3_loc["bucketOwner"] + formatted_source = {"s3Location": formatted_s3} + elif "bytes" in source: formatted_source = {"bytes": source["bytes"]} result = {"format": image["format"], "source": formatted_source} return {"image": result} @@ -527,8 +541,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "video" in content: video = content["video"] source = video["source"] - formatted_source = {} - if "bytes" in source: + formatted_source: dict[str, Any] = {} + if "s3Location" in source: + s3_loc = source["s3Location"] + formatted_s3: dict[str, Any] = {"uri": s3_loc["uri"]} + if "bucketOwner" in s3_loc: + formatted_s3["bucketOwner"] = s3_loc["bucketOwner"] + formatted_source = {"s3Location": formatted_s3} + elif "bytes" in source: formatted_source = {"bytes": source["bytes"]} result = {"format": video["format"], "source": formatted_source} return {"video": result} diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 855e1ef5c..2c8489c7f 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -228,16 +228,46 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten # Gemini FunctionResponses are constructed from tool result blocks. Function name is required but is not # available in tool result blocks, hence the mapping. tool_use_id_to_name: dict[str, str] = {} + s3_source_warned = False + + contents = [] + for message in messages: + parts = [] + for content in message["content"]: + # Check for S3 sources and skip with warning + if self._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by Gemini, skipping content block") + s3_source_warned = True + continue + parts.append(self._format_request_content_part(content, tool_use_id_to_name)) + + if parts: + contents.append( + genai.types.Content( + parts=parts, + role="user" if message["role"] == "user" else "model", + ) + ) - return [ - genai.types.Content( - parts=[ - self._format_request_content_part(content, tool_use_id_to_name) for content in message["content"] - ], - role="user" if message["role"] == "user" else "model", - ) - for message in messages - ] + return contents + + def _has_s3_source(self, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]: """Format tool specs into Gemini tools. diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index ce0367bf5..943d8fb14 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -172,16 +172,25 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None """ formatted_messages: list[dict[str, Any]] formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + s3_source_warned = False for message in messages: contents = message["content"] + # Filter out S3 sources and unsupported block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if self._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by LlamaAPI, skipping content block") + s3_source_warned = True + continue + filtered_contents.append(content) + formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = "" - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + formatted_contents = [self._format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) for content in contents @@ -206,6 +215,23 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + def _has_s3_source(self, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False + def format_request( self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index ca838f3d7..1228fd1a4 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -291,6 +291,7 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None) Formatted messages array compatible with llama.cpp. """ formatted_messages: list[dict[str, Any]] = [] + s3_source_warned = False # Add system prompt if provided if system_prompt: @@ -299,11 +300,19 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None) for message in messages: contents = message["content"] - formatted_contents = [ - self._format_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + # Filter out S3 sources and unsupported block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if self._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by llama.cpp, skipping content block") + s3_source_warned = True + continue + filtered_contents.append(content) + + formatted_contents = [self._format_message_content(content) for content in filtered_contents] formatted_tool_calls = [ self._format_tool_call( { @@ -336,6 +345,23 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None) return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + def _has_s3_source(self, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False + def _format_request( self, messages: Messages, diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 4ec77ccfe..71a083f45 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -199,6 +199,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None A Mistral compatible messages array. """ formatted_messages: list[dict[str, Any]] = [] + s3_source_warned = False if system_prompt: formatted_messages.append({"role": "system", "content": system_prompt}) @@ -212,6 +213,13 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None tool_messages: list[dict[str, Any]] = [] for content in contents: + # Check for S3 sources and skip with warning + if self._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by Mistral, skipping content block") + s3_source_warned = True + continue + if "text" in content: formatted_content = self._format_request_message_content(content) if isinstance(formatted_content, str): @@ -236,6 +244,23 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None return formatted_messages + def _has_s3_source(self, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False + def format_request( self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 8d72aa534..069432827 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -159,13 +159,37 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None An Ollama compatible messages array. """ system_message = [{"role": "system", "content": system_prompt}] if system_prompt else [] + s3_source_warned = False - return system_message + [ - formatted_message - for message in messages - for content in message["content"] - for formatted_message in self._format_request_message_contents(message["role"], content) - ] + formatted_messages = [] + for message in messages: + for content in message["content"]: + # Check for S3 sources and skip with warning + if self._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by Ollama, skipping content block") + s3_source_warned = True + continue + formatted_messages.extend(self._format_request_message_contents(message["role"], content)) + + return system_message + formatted_messages + + def _has_s3_source(self, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False def format_request( self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index d9266212b..c2360438f 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -328,6 +328,7 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic List of formatted messages. """ formatted_messages = [] + s3_source_warned = False for message in messages: contents = message["content"] @@ -338,11 +339,19 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic "reasoningContent is not supported in multi-turn conversations with the Chat Completions API." ) - formatted_contents = [ - cls.format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]) - ] + # Filter out S3 sources and content blocks that shouldn't be formatted + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]): + continue + if cls._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by OpenAI, skipping content block") + s3_source_warned = True + continue + filtered_contents.append(content) + + formatted_contents = [cls.format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content ] @@ -369,6 +378,24 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic return formatted_messages + @classmethod + def _has_s3_source(cls, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False + @classmethod def format_request_messages( cls, diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index f306d649b..96d4eaac8 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -214,15 +214,28 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None """ formatted_messages: list[dict[str, Any]] formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + s3_source_warned = False for message in messages: contents = message["content"] + # Filter out S3 sources + filtered_contents = [] + for content in contents: + if self._has_s3_source(content): + if not s3_source_warned: + logger.warning("S3 sources are not supported by Writer, skipping content block") + s3_source_warned = True + continue + filtered_contents.append(content) + # Only palmyra V5 support multiple content. Other models support only '{"content": "text_content"}' if self.get_config().get("model_id", "") == "palmyra-x5": - formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents) + formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision( + filtered_contents + ) else: - formatted_contents = self._format_request_message_contents(contents) + formatted_contents = self._format_request_message_contents(filtered_contents) formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) @@ -245,6 +258,23 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + def _has_s3_source(self, content: ContentBlock) -> bool: + """Check if a content block contains an S3 source. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an S3 source, False otherwise. + """ + if "image" in content: + return "s3Location" in content["image"].get("source", {}) + if "document" in content: + return "s3Location" in content["document"].get("source", {}) + if "video" in content: + return "s3Location" in content["video"].get("source", {}) + return False + def format_request( self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 462d8af34..8cdf97b54 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -15,14 +15,34 @@ """Supported document formats.""" -class DocumentSource(TypedDict): +class S3Location(TypedDict, total=False): + """A storage location in an Amazon S3 bucket. + + Used by Bedrock to reference media files stored in S3 instead of passing raw bytes. + + - Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html + + Attributes: + uri: An object URI starting with `s3://`. + bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. + """ + + uri: str + bucketOwner: str + + +class DocumentSource(TypedDict, total=False): """Contains the content of a document. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the document. + s3Location: S3 location of the document (Bedrock only). """ bytes: bytes + s3Location: S3Location class DocumentContent(TypedDict, total=False): @@ -45,14 +65,18 @@ class DocumentContent(TypedDict, total=False): """Supported image formats.""" -class ImageSource(TypedDict): +class ImageSource(TypedDict, total=False): """Contains the content of an image. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the image. + s3Location: S3 location of the image (Bedrock only). """ bytes: bytes + s3Location: S3Location class ImageContent(TypedDict): @@ -71,14 +95,18 @@ class ImageContent(TypedDict): """Supported video formats.""" -class VideoSource(TypedDict): +class VideoSource(TypedDict, total=False): """Contains the content of a video. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the video. + s3Location: S3 location of the video (Bedrock only). """ bytes: bytes + s3Location: S3Location class VideoContent(TypedDict): diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 74bbb8d45..8d550564a 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -866,3 +866,93 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): model.format_request(messages, tool_choice=None) assert len(captured_warnings) == 0 + + +def test_format_request_filters_s3_source_image(model, model_id, max_tokens, caplog): + """Test that images with S3 sources are filtered out with warning.""" + import logging + + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"s3Location": {"uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "look at this image"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "S3 sources are not supported by Anthropic" in caplog.text + + +def test_format_request_filters_s3_source_document(model, model_id, max_tokens, caplog): + """Test that documents with S3 sources are filtered out with warning.""" + import logging + + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"s3Location": {"uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "analyze this document"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "S3 sources are not supported by Anthropic" in caplog.text + + +def test_has_s3_source_detection(model): + """Test _has_s3_source correctly detects S3 sources in various content types.""" + # Image with S3 source + assert model._has_s3_source({"image": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}) is True + + # Image with bytes source + assert model._has_s3_source({"image": {"source": {"bytes": b"data"}}}) is False + + # Document with S3 source + assert model._has_s3_source({"document": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}) is True + + # Document with bytes source + assert model._has_s3_source({"document": {"source": {"bytes": b"data"}}}) is False + + # Video with S3 source + assert model._has_s3_source({"video": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}) is True + + # Video with bytes source + assert model._has_s3_source({"video": {"source": {"bytes": b"data"}}}) is False + + # Text content (no source) + assert model._has_s3_source({"text": "hello"}) is False + + # Tool use (no source) + assert model._has_s3_source({"toolUse": {"name": "test", "input": {}, "toolUseId": "123"}}) is False diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 833b14729..4d1e44cfc 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1788,7 +1788,7 @@ def test_format_request_filters_image_content_blocks(model, model_id): def test_format_request_filters_nested_image_s3_fields(model, model_id): - """Test that s3Location is filtered out and only bytes source is preserved.""" + """Test that s3Location is used when present, even if bytes is also in source.""" messages = [ { "role": "user", @@ -1798,7 +1798,7 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id): "format": "png", "source": { "bytes": b"image_data", - "s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"}, + "s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "123456789012"}, }, } } @@ -1809,8 +1809,104 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id): formatted_request = model._format_request(messages) image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + # s3Location should be used when present + assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png", "bucketOwner": "123456789012"}} + assert "bytes" not in image_source + + +def test_format_request_image_s3_location_only(model, model_id): + """Test that image with only s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": { + "s3Location": {"uri": "s3://my-bucket/image.png"}, + }, + } + } + ], + } + ] + + formatted_request = model._format_request(messages) + image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + + assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}} + + +def test_format_request_image_bytes_only(model, model_id): + """Test that image with only bytes source is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": {"bytes": b"image_data"}, + } + } + ], + } + ] + + formatted_request = model._format_request(messages) + image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + assert image_source == {"bytes": b"image_data"} - assert "s3Location" not in image_source + + +def test_format_request_document_s3_location(model, model_id): + """Test that document with s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"}, + }, + } + } + ], + } + ] + + formatted_request = model._format_request(messages) + document = formatted_request["messages"][0]["content"][0]["document"] + + assert document["source"] == {"s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"}} + + +def test_format_request_video_s3_location(model, model_id): + """Test that video with s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "video": { + "format": "mp4", + "source": { + "s3Location": {"uri": "s3://my-bucket/video.mp4"}, + }, + } + } + ], + } + ] + + formatted_request = model._format_request(messages) + video_source = formatted_request["messages"][0]["content"][0]["video"]["source"] + + assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}} def test_format_request_filters_document_content_blocks(model, model_id): diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 7c1d18998..9a6427653 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -1246,3 +1246,90 @@ def test_init_with_both_client_and_client_args_raises_error(): with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with S3 sources are filtered out with warning.""" + import logging + + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"s3Location": {"uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "S3 sources are not supported by OpenAI" in caplog.text + + +def test_format_request_filters_s3_source_document(model, caplog): + """Test that documents with S3 sources are filtered out with warning.""" + import logging + + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"s3Location": {"uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "S3 sources are not supported by OpenAI" in caplog.text + + +def test_has_s3_source_detection(): + """Test _has_s3_source class method correctly detects S3 sources.""" + from strands.models.openai import OpenAIModel + + # Image with S3 source + assert OpenAIModel._has_s3_source({"image": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}) is True + + # Image with bytes source + assert OpenAIModel._has_s3_source({"image": {"source": {"bytes": b"data"}}}) is False + + # Document with S3 source + assert OpenAIModel._has_s3_source({"document": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}) is True + + # Document with bytes source + assert OpenAIModel._has_s3_source({"document": {"source": {"bytes": b"data"}}}) is False + + # Video with S3 source + assert OpenAIModel._has_s3_source({"video": {"source": {"s3Location": {"uri": "s3://bucket/key"}}}}) is True + + # Video with bytes source + assert OpenAIModel._has_s3_source({"video": {"source": {"bytes": b"data"}}}) is False + + # Text content (no source) + assert OpenAIModel._has_s3_source({"text": "hello"}) is False diff --git a/tests/strands/types/test_media.py b/tests/strands/types/test_media.py new file mode 100644 index 000000000..9c2e238f5 --- /dev/null +++ b/tests/strands/types/test_media.py @@ -0,0 +1,102 @@ +"""Tests for media type definitions.""" + +from strands.types.media import ( + DocumentFormat, + DocumentSource, + ImageFormat, + ImageSource, + S3Location, + VideoFormat, + VideoSource, +) + + +class TestS3Location: + """Tests for S3Location TypedDict.""" + + def test_s3_location_with_uri_only(self): + """Test S3Location with only uri field.""" + s3_loc: S3Location = {"uri": "s3://my-bucket/path/to/file.pdf"} + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert "bucketOwner" not in s3_loc + + def test_s3_location_with_bucket_owner(self): + """Test S3Location with both uri and bucketOwner fields.""" + s3_loc: S3Location = { + "uri": "s3://my-bucket/path/to/file.pdf", + "bucketOwner": "123456789012", + } + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert s3_loc["bucketOwner"] == "123456789012" + + +class TestDocumentSource: + """Tests for DocumentSource TypedDict.""" + + def test_document_source_with_bytes(self): + """Test DocumentSource with bytes content.""" + doc_source: DocumentSource = {"bytes": b"document content"} + + assert doc_source["bytes"] == b"document content" + assert "s3Location" not in doc_source + + def test_document_source_with_s3_location(self): + """Test DocumentSource with s3Location.""" + doc_source: DocumentSource = { + "s3Location": { + "uri": "s3://my-bucket/docs/report.pdf", + "bucketOwner": "123456789012", + } + } + + assert "bytes" not in doc_source + assert doc_source["s3Location"]["uri"] == "s3://my-bucket/docs/report.pdf" + assert doc_source["s3Location"]["bucketOwner"] == "123456789012" + + +class TestImageSource: + """Tests for ImageSource TypedDict.""" + + def test_image_source_with_bytes(self): + """Test ImageSource with bytes content.""" + img_source: ImageSource = {"bytes": b"image content"} + + assert img_source["bytes"] == b"image content" + assert "s3Location" not in img_source + + def test_image_source_with_s3_location(self): + """Test ImageSource with s3Location.""" + img_source: ImageSource = { + "s3Location": { + "uri": "s3://my-bucket/images/photo.png", + } + } + + assert "bytes" not in img_source + assert img_source["s3Location"]["uri"] == "s3://my-bucket/images/photo.png" + + +class TestVideoSource: + """Tests for VideoSource TypedDict.""" + + def test_video_source_with_bytes(self): + """Test VideoSource with bytes content.""" + vid_source: VideoSource = {"bytes": b"video content"} + + assert vid_source["bytes"] == b"video content" + assert "s3Location" not in vid_source + + def test_video_source_with_s3_location(self): + """Test VideoSource with s3Location.""" + vid_source: VideoSource = { + "s3Location": { + "uri": "s3://my-bucket/videos/clip.mp4", + "bucketOwner": "987654321098", + } + } + + assert "bytes" not in vid_source + assert vid_source["s3Location"]["uri"] == "s3://my-bucket/videos/clip.mp4" + assert vid_source["s3Location"]["bucketOwner"] == "987654321098"