Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
An Anthropic messages array.
"""
formatted_messages = []
s3_source_warned = False

for message in messages:
formatted_contents: list[dict[str, Any]] = []
Expand All @@ -189,13 +190,37 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
formatted_contents[-1]["cache_control"] = {"type": "ephemeral"}
continue

# Check for S3 sources in image, document, or video content
if self._has_s3_source(content):
if not s3_source_warned:
logger.warning("S3 sources are not supported by Anthropic, skipping content block")
s3_source_warned = True
continue

formatted_contents.append(self._format_request_message_content(content))

if formatted_contents:
formatted_messages.append({"content": formatted_contents, "role": message["role"]})

return formatted_messages

def _has_s3_source(self, content: ContentBlock) -> bool:
"""Check if a content block contains an S3 source.

Args:
content: Content block to check.

Returns:
True if the content block contains an S3 source, False otherwise.
"""
if "image" in content:
return "s3Location" in content["image"].get("source", {})
if "document" in content:
return "s3Location" in content["document"].get("source", {})
if "video" in content:
return "s3Location" in content["video"].get("source", {})
return False

def format_request(
self,
messages: Messages,
Expand Down
32 changes: 26 additions & 6 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "format" in document:
result["format"] = document["format"]

# Handle source
# Handle source - supports bytes or s3Location
if "source" in document:
result["source"] = {"bytes": document["source"]["bytes"]}
source = document["source"]
if "s3Location" in source:
s3_loc = source["s3Location"]
formatted_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
if "bucketOwner" in s3_loc:
formatted_s3["bucketOwner"] = s3_loc["bucketOwner"]
result["source"] = {"s3Location": formatted_s3}
elif "bytes" in source:
result["source"] = {"bytes": source["bytes"]}

# Handle optional fields
if "citations" in document and document["citations"] is not None:
Expand All @@ -462,8 +470,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "image" in content:
image = content["image"]
source = image["source"]
formatted_source = {}
if "bytes" in source:
formatted_source: dict[str, Any] = {}
if "s3Location" in source:
s3_loc = source["s3Location"]
formatted_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
if "bucketOwner" in s3_loc:
formatted_s3["bucketOwner"] = s3_loc["bucketOwner"]
formatted_source = {"s3Location": formatted_s3}
elif "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": image["format"], "source": formatted_source}
return {"image": result}
Expand Down Expand Up @@ -527,8 +541,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
if "video" in content:
video = content["video"]
source = video["source"]
formatted_source = {}
if "bytes" in source:
formatted_source: dict[str, Any] = {}
if "s3Location" in source:
s3_loc = source["s3Location"]
formatted_s3: dict[str, Any] = {"uri": s3_loc["uri"]}
if "bucketOwner" in s3_loc:
formatted_s3["bucketOwner"] = s3_loc["bucketOwner"]
formatted_source = {"s3Location": formatted_s3}
elif "bytes" in source:
formatted_source = {"bytes": source["bytes"]}
result = {"format": video["format"], "source": formatted_source}
return {"video": result}
Expand Down
48 changes: 39 additions & 9 deletions src/strands/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,46 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten
# Gemini FunctionResponses are constructed from tool result blocks. Function name is required but is not
# available in tool result blocks, hence the mapping.
tool_use_id_to_name: dict[str, str] = {}
s3_source_warned = False

contents = []
for message in messages:
parts = []
for content in message["content"]:
# Check for S3 sources and skip with warning
if self._has_s3_source(content):
if not s3_source_warned:
logger.warning("S3 sources are not supported by Gemini, skipping content block")
s3_source_warned = True
continue
parts.append(self._format_request_content_part(content, tool_use_id_to_name))

if parts:
contents.append(
genai.types.Content(
parts=parts,
role="user" if message["role"] == "user" else "model",
)
)

return [
genai.types.Content(
parts=[
self._format_request_content_part(content, tool_use_id_to_name) for content in message["content"]
],
role="user" if message["role"] == "user" else "model",
)
for message in messages
]
return contents

def _has_s3_source(self, content: ContentBlock) -> bool:
"""Check if a content block contains an S3 source.

Args:
content: Content block to check.

Returns:
True if the content block contains an S3 source, False otherwise.
"""
if "image" in content:
return "s3Location" in content["image"].get("source", {})
if "document" in content:
return "s3Location" in content["document"].get("source", {})
if "video" in content:
return "s3Location" in content["video"].get("source", {})
return False

def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]:
"""Format tool specs into Gemini tools.
Expand Down
36 changes: 31 additions & 5 deletions src/strands/models/llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,25 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None
"""
formatted_messages: list[dict[str, Any]]
formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
s3_source_warned = False

for message in messages:
contents = message["content"]

# Filter out S3 sources and unsupported block types
filtered_contents = []
for content in contents:
if any(block_type in content for block_type in ["toolResult", "toolUse"]):
continue
if self._has_s3_source(content):
if not s3_source_warned:
logger.warning("S3 sources are not supported by LlamaAPI, skipping content block")
s3_source_warned = True
continue
filtered_contents.append(content)

formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = ""
formatted_contents = [
self._format_request_message_content(content)
for content in contents
if not any(block_type in content for block_type in ["toolResult", "toolUse"])
]
formatted_contents = [self._format_request_message_content(content) for content in filtered_contents]
formatted_tool_calls = [
self._format_request_message_tool_call(content["toolUse"])
for content in contents
Expand All @@ -206,6 +215,23 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None

return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

def _has_s3_source(self, content: ContentBlock) -> bool:
"""Check if a content block contains an S3 source.

Args:
content: Content block to check.

Returns:
True if the content block contains an S3 source, False otherwise.
"""
if "image" in content:
return "s3Location" in content["image"].get("source", {})
if "document" in content:
return "s3Location" in content["document"].get("source", {})
if "video" in content:
return "s3Location" in content["video"].get("source", {})
return False

def format_request(
self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None
) -> dict[str, Any]:
Expand Down
36 changes: 31 additions & 5 deletions src/strands/models/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None)
Formatted messages array compatible with llama.cpp.
"""
formatted_messages: list[dict[str, Any]] = []
s3_source_warned = False

# Add system prompt if provided
if system_prompt:
Expand All @@ -299,11 +300,19 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None)
for message in messages:
contents = message["content"]

formatted_contents = [
self._format_message_content(content)
for content in contents
if not any(block_type in content for block_type in ["toolResult", "toolUse"])
]
# Filter out S3 sources and unsupported block types
filtered_contents = []
for content in contents:
if any(block_type in content for block_type in ["toolResult", "toolUse"]):
continue
if self._has_s3_source(content):
if not s3_source_warned:
logger.warning("S3 sources are not supported by llama.cpp, skipping content block")
s3_source_warned = True
continue
filtered_contents.append(content)

formatted_contents = [self._format_message_content(content) for content in filtered_contents]
formatted_tool_calls = [
self._format_tool_call(
{
Expand Down Expand Up @@ -336,6 +345,23 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None)

return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

def _has_s3_source(self, content: ContentBlock) -> bool:
"""Check if a content block contains an S3 source.

Args:
content: Content block to check.

Returns:
True if the content block contains an S3 source, False otherwise.
"""
if "image" in content:
return "s3Location" in content["image"].get("source", {})
if "document" in content:
return "s3Location" in content["document"].get("source", {})
if "video" in content:
return "s3Location" in content["video"].get("source", {})
return False

def _format_request(
self,
messages: Messages,
Expand Down
25 changes: 25 additions & 0 deletions src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None
A Mistral compatible messages array.
"""
formatted_messages: list[dict[str, Any]] = []
s3_source_warned = False

if system_prompt:
formatted_messages.append({"role": "system", "content": system_prompt})
Expand All @@ -212,6 +213,13 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None
tool_messages: list[dict[str, Any]] = []

for content in contents:
# Check for S3 sources and skip with warning
if self._has_s3_source(content):
if not s3_source_warned:
logger.warning("S3 sources are not supported by Mistral, skipping content block")
s3_source_warned = True
continue

if "text" in content:
formatted_content = self._format_request_message_content(content)
if isinstance(formatted_content, str):
Expand All @@ -236,6 +244,23 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None

return formatted_messages

def _has_s3_source(self, content: ContentBlock) -> bool:
"""Check if a content block contains an S3 source.

Args:
content: Content block to check.

Returns:
True if the content block contains an S3 source, False otherwise.
"""
if "image" in content:
return "s3Location" in content["image"].get("source", {})
if "document" in content:
return "s3Location" in content["document"].get("source", {})
if "video" in content:
return "s3Location" in content["video"].get("source", {})
return False

def format_request(
self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None
) -> dict[str, Any]:
Expand Down
36 changes: 30 additions & 6 deletions src/strands/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,37 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None
An Ollama compatible messages array.
"""
system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
s3_source_warned = False

return system_message + [
formatted_message
for message in messages
for content in message["content"]
for formatted_message in self._format_request_message_contents(message["role"], content)
]
formatted_messages = []
for message in messages:
for content in message["content"]:
# Check for S3 sources and skip with warning
if self._has_s3_source(content):
if not s3_source_warned:
logger.warning("S3 sources are not supported by Ollama, skipping content block")
s3_source_warned = True
continue
formatted_messages.extend(self._format_request_message_contents(message["role"], content))

return system_message + formatted_messages

def _has_s3_source(self, content: ContentBlock) -> bool:
"""Check if a content block contains an S3 source.

Args:
content: Content block to check.

Returns:
True if the content block contains an S3 source, False otherwise.
"""
if "image" in content:
return "s3Location" in content["image"].get("source", {})
if "document" in content:
return "s3Location" in content["document"].get("source", {})
if "video" in content:
return "s3Location" in content["video"].get("source", {})
return False

def format_request(
self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None
Expand Down
Loading