38 changes: 19 additions & 19 deletions src/strands/models/bedrock.py
@@ -7,7 +7,8 @@
 import json
 import logging
 import os
-from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+import threading
+from typing import Any, AsyncGenerator, Generator, Iterable, Literal, Optional, Type, TypeVar, Union

 import boto3
 from botocore.config import Config as BotocoreConfig
@@ -376,40 +377,41 @@ async def stream(
             ModelThrottledException: If the model service is throttling requests.
         """

-        def callback(event: Optional[StreamEvent] = None) -> None:
-            loop.call_soon_threadsafe(queue.put_nowait, event)
-            if event is None:
-                return
+        def target() -> None:
Member: nit: should this be private?

Member Author: The diff cuts off above, but this is a nested function of stream and so is private by that nature.

+            try:
+                for event in self._stream(messages, tool_specs, system_prompt):
+                    loop.call_soon_threadsafe(queue.put_nowait, event)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, None)

         loop = asyncio.get_event_loop()
-        queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
+        queue: asyncio.Queue[StreamEvent | Exception | None] = asyncio.Queue()

-        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
Member Author: asyncio.to_thread uses a thread pool executor that is not directly configurable. Consequently, we have to create our own threads in order to control the daemon mode.
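A minimal sketch of that constraint (illustrative snippet, not part of this PR; assumes Python 3.9+, where the default executor's threads are non-daemon):

import asyncio
import threading

def report() -> str:
    # asyncio.to_thread always routes through the loop's default
    # ThreadPoolExecutor; it exposes no argument for choosing an
    # executor or for requesting daemon threads.
    t = threading.current_thread()
    return f"{t.name}: daemon={t.daemon}"

async def main() -> None:
    print(await asyncio.to_thread(report))  # e.g. "asyncio_0: daemon=False"

asyncio.run(main())

Non-daemon pool threads can keep the interpreter alive at shutdown if a blocking call never returns, which is why this change spawns its own daemon threads.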

-        task = asyncio.create_task(thread)
+        threading.Thread(target=target, daemon=True).start()
Member Author: One downside here is that we create a new thread on every stream invocation rather than pulling from a pool of existing threads. In this context, however, the impact should be negligible, as bedrock.converse_stream is a long-running call.

Member: I'm not as concerned about the cost (time/space) of creating the new thread. But for highly concurrent use cases this concerns me because of the failure mode.

In the previous case, it seems we would have waited until a thread became available in the ThreadPoolExecutor, allowing the request to succeed. In the current case, we are introducing a new thread-exhaustion failure mode.

Member Author: This is a really good callout and should be thought through carefully. I'll note that I avoided using a ThreadPoolExecutor here because there is no easy way to configure daemon mode. We would have to derive our own version with some overrides.
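One possible shape for such a derived version — a fixed-size pool of pre-started daemon workers (a hypothetical sketch; the class name and sizing are illustrative, not from this PR):

import logging
import queue
import threading
from typing import Any, Callable

class DaemonThreadPool:
    """Fixed-size pool of daemon worker threads (illustrative sketch)."""

    def __init__(self, max_workers: int = 8) -> None:
        self._work: "queue.Queue[tuple[Callable[..., Any], tuple[Any, ...]]]" = queue.Queue()
        for i in range(max_workers):
            threading.Thread(target=self._worker, name=f"daemon-pool-{i}", daemon=True).start()

    def _worker(self) -> None:
        while True:
            fn, args = self._work.get()
            try:
                fn(*args)
            except Exception:
                # Keep the worker alive even if a task fails.
                logging.exception("worker task failed")
            finally:
                self._work.task_done()

    def submit(self, fn: Callable[..., Any], *args: Any) -> None:
        # Work queues behind busy workers instead of spawning new threads,
        # trading the thread-exhaustion failure mode for back-pressure.
        self._work.put((fn, args))

Under this design, submissions wait in the queue when all workers are busy, so concurrency is bounded by max_workers rather than by the process thread limit.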


         while True:
             event = await queue.get()
             if event is None:
                 break
+            if isinstance(event, Exception):
+                raise event
Member Author: Python threads do not propagate raised exceptions to the main thread automatically. Consequently, we need to communicate exceptions explicitly, which we can do through our existing queue.
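For illustration (standalone snippet, not from this PR), an exception raised in a worker thread is reported by threading.excepthook but never re-raised in the spawning thread:

import threading

def boom() -> None:
    raise ValueError("lost in the worker thread")

t = threading.Thread(target=boom)
t.start()
t.join()  # returns normally; the ValueError does not propagate here

Relaying the exception object through the queue and re-raising it in the consumer, as this change does, is what restores normal error propagation to callers of stream.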


             yield event

-        await task

     def _stream(
         self,
-        callback: Callable[..., None],
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-    ) -> None:
+    ) -> Generator[StreamEvent, None, None]:
         """Stream conversation with the Bedrock model.

         This method operates in a separate thread to avoid blocking the async event loop with the call to
         Bedrock's converse_stream.

         Args:
-            callback: Function to send events to the main thread.
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
@@ -426,7 +428,6 @@ def _stream(
         streaming = self.config.get("streaming", True)

         try:
-            logger.debug("got response from model")
             if streaming:
                 response = self.client.converse_stream(**request)
                 for chunk in response["stream"]:
@@ -438,22 +439,22 @@
                         guardrail_data = chunk["metadata"]["trace"]["guardrail"]
                         if self._has_blocked_guardrail(guardrail_data):
                             for event in self._generate_redaction_events():
-                                callback(event)
+                                yield event

-                    callback(chunk)
+                    yield chunk

             else:
                 response = self.client.converse(**request)
                 for event in self._convert_non_streaming_to_streaming(response):
-                    callback(event)
+                    yield event

                 if (
                     "trace" in response
                     and "guardrail" in response["trace"]
                     and self._has_blocked_guardrail(response["trace"]["guardrail"])
                 ):
                     for event in self._generate_redaction_events():
-                        callback(event)
+                        yield event

         except ClientError as e:
             error_message = str(e)
@@ -494,7 +495,6 @@ def _stream(
             raise e

         finally:
-            callback()
             logger.debug("finished streaming response from model")

     def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
1 change: 0 additions & 1 deletion tests/strands/models/test_bedrock.py
@@ -1197,7 +1197,6 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
assert "formatting request" in log_text
assert "request=<" in log_text
assert "invoking model" in log_text
assert "got response from model" in log_text
assert "finished streaming response from model" in log_text

