From 5b243c1f5e51183386220f1ec05ee57e8281af39 Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Fri, 1 Aug 2025 14:45:32 +0000
Subject: [PATCH] models - bedrock - daemon thread

---
 src/strands/models/bedrock.py        | 38 ++++++++++++++--------------
 tests/strands/models/test_bedrock.py |  1 -
 2 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 4ea1453a4..82e53a639 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -7,7 +7,8 @@
 import json
 import logging
 import os
-from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+import threading
+from typing import Any, AsyncGenerator, Generator, Iterable, Literal, Optional, Type, TypeVar, Union
 
 import boto3
 from botocore.config import Config as BotocoreConfig
@@ -376,40 +377,41 @@ async def stream(
             ModelThrottledException: If the model service is throttling requests.
         """
 
-        def callback(event: Optional[StreamEvent] = None) -> None:
-            loop.call_soon_threadsafe(queue.put_nowait, event)
-            if event is None:
-                return
+        def target() -> None:
+            try:
+                for event in self._stream(messages, tool_specs, system_prompt):
+                    loop.call_soon_threadsafe(queue.put_nowait, event)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, None)
 
         loop = asyncio.get_event_loop()
-        queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
+        queue: asyncio.Queue[StreamEvent | Exception | None] = asyncio.Queue()
 
-        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
-        task = asyncio.create_task(thread)
+        threading.Thread(target=target, daemon=True).start()
 
         while True:
             event = await queue.get()
             if event is None:
                 break
+            if isinstance(event, Exception):
+                raise event
 
             yield event
 
-        await task
-
     def _stream(
         self,
-        callback: Callable[..., None],
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-    ) -> None:
+    ) -> Generator[StreamEvent, None, None]:
         """Stream conversation with the Bedrock model.
 
         This method operates in a separate thread to avoid blocking the async event loop with the call to Bedrock's
         converse_stream.
 
         Args:
-            callback: Function to send events to the main thread.
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
@@ -426,7 +428,6 @@ def _stream(
         streaming = self.config.get("streaming", True)
 
         try:
-            logger.debug("got response from model")
             if streaming:
                 response = self.client.converse_stream(**request)
                 for chunk in response["stream"]:
@@ -438,14 +439,14 @@
                         guardrail_data = chunk["metadata"]["trace"]["guardrail"]
                         if self._has_blocked_guardrail(guardrail_data):
                             for event in self._generate_redaction_events():
-                                callback(event)
+                                yield event
 
-                    callback(chunk)
+                    yield chunk
 
             else:
                 response = self.client.converse(**request)
                 for event in self._convert_non_streaming_to_streaming(response):
-                    callback(event)
+                    yield event
 
                 if (
                     "trace" in response
@@ -453,7 +454,7 @@
                     and self._has_blocked_guardrail(response["trace"]["guardrail"])
                 ):
                     for event in self._generate_redaction_events():
-                        callback(event)
+                        yield event
 
         except ClientError as e:
             error_message = str(e)
@@ -494,7 +495,6 @@ def _stream(
             raise e
 
         finally:
-            callback()
             logger.debug("finished streaming response from model")
 
     def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 0a2846adf..b8d7ce4e4 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -1197,7 +1197,6 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
     assert "formatting request" in log_text
     assert "request=<" in log_text
     assert "invoking model" in log_text
-    assert "got response from model" in log_text
     assert "finished streaming response from model" in log_text
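
The pattern this patch adopts, iterating a blocking generator on a daemon
thread and handing each event to the running loop via
loop.call_soon_threadsafe and an asyncio.Queue, with None as the
end-of-stream sentinel and exceptions forwarded for re-raising on the async
side, can be exercised in isolation. Below is a minimal, self-contained
sketch of that bridge, not the SDK code itself: the names blocking_events,
stream_events, and main are illustrative stand-ins, and the sketch uses
asyncio.get_running_loop(), which inside a coroutine resolves to the same
loop as the patch's asyncio.get_event_loop().

import asyncio
import threading
import time
from typing import AsyncGenerator, Generator


def blocking_events() -> Generator[str, None, None]:
    """Stand-in for the blocking iteration over Bedrock's converse_stream."""
    for i in range(3):
        time.sleep(0.1)  # simulate a blocking network read
        yield f"event-{i}"


async def stream_events() -> AsyncGenerator[str, None]:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[str | Exception | None] = asyncio.Queue()

    def target() -> None:
        try:
            for event in blocking_events():
                # asyncio.Queue is not thread-safe; schedule the put on the loop's thread.
                loop.call_soon_threadsafe(queue.put_nowait, event)
        except Exception as e:
            # Forward the exception so the async consumer can re-raise it.
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            # None is the end-of-stream sentinel.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    # daemon=True keeps a wedged producer thread from blocking interpreter exit.
    threading.Thread(target=target, daemon=True).start()

    while True:
        event = await queue.get()
        if event is None:
            break
        if isinstance(event, Exception):
            raise event
        yield event


async def main() -> None:
    async for event in stream_events():
        print(event)


if __name__ == "__main__":
    asyncio.run(main())

One trade-off worth noting: because the thread is a daemon and nothing awaits
it (the patch drops the old `await task`), a consumer that abandons the async
generator early leaves the producer running until it finishes on its own or
the process exits.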