# models - bedrock - daemon thread #593
```diff
@@ -7,7 +7,8 @@
 import json
 import logging
 import os
-from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+import threading
+from typing import Any, AsyncGenerator, Generator, Iterable, Literal, Optional, Type, TypeVar, Union
 
 import boto3
 from botocore.config import Config as BotocoreConfig
```
```diff
@@ -376,40 +377,41 @@ async def stream(
             ModelThrottledException: If the model service is throttling requests.
         """
 
-        def callback(event: Optional[StreamEvent] = None) -> None:
-            loop.call_soon_threadsafe(queue.put_nowait, event)
-            if event is None:
-                return
+        def target() -> None:
+            try:
+                for event in self._stream(messages, tool_specs, system_prompt):
+                    loop.call_soon_threadsafe(queue.put_nowait, event)
+            except Exception as e:
+                loop.call_soon_threadsafe(queue.put_nowait, e)
+            finally:
+                loop.call_soon_threadsafe(queue.put_nowait, None)
 
         loop = asyncio.get_event_loop()
-        queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
+        queue: asyncio.Queue[StreamEvent | Exception | None] = asyncio.Queue()
 
-        thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
-        task = asyncio.create_task(thread)
+        threading.Thread(target=target, daemon=True).start()
```
**Member (Author):** One downside here is that we create a new thread on every stream invocation rather than pulling from a pool of existing threads. In this context, however, the impact should be negligible as …

**Member:** I'm not as concerned about the cost (time/space) of creating the new thread. But for highly concurrent use cases this concerns me because of the failure mode. In the previous case, it seems we would have waited until a thread became available in the ThreadPoolExecutor, allowing the request to succeed. In the current case, we are introducing a new thread-exhaustion failure mode.

**Member (Author):** This is a really good callout and should be thought through carefully. I'll note that I avoided using a ThreadPoolExecutor here because there is no easy way to configure daemon mode; we would have to derive our own version with some overrides.
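To illustrate the "derive our own version" idea, here is a minimal, hypothetical sketch of a bounded pool of daemon worker threads built directly on `queue.Queue`. The `DaemonPool` name and its API are invented for this example and are not part of this PR; it simply shows one way to get pooled daemon workers, since `ThreadPoolExecutor` exposes no supported option for daemon threads:

```python
import queue
import threading
from typing import Any, Callable


class DaemonPool:
    """Hypothetical bounded pool whose workers are daemon threads.

    Sketch only: submitted work waits for a free worker instead of
    spawning a new thread per invocation, and the daemon workers will
    not keep the interpreter alive at shutdown.
    """

    def __init__(self, max_workers: int) -> None:
        self._tasks: "queue.Queue[Callable[[], Any]]" = queue.Queue()
        for _ in range(max_workers):
            threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self) -> None:
        while True:
            task = self._tasks.get()  # block until work arrives
            try:
                task()
            finally:
                self._tasks.task_done()

    def submit(self, fn: Callable[[], Any]) -> None:
        # Enqueue the callable; the next idle daemon worker runs it.
        self._tasks.put(fn)
```

A pool along these lines would restore the bounded-concurrency behavior the reviewer describes (callers queue up rather than exhausting threads), at the cost of maintaining the pool ourselves.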
```diff
         while True:
             event = await queue.get()
             if event is None:
                 break
+            if isinstance(event, Exception):
+                raise event
```
**Member (Author):** Python threads do not propagate raised exceptions to the main thread automatically. Consequently, we need to communicate exceptions explicitly, which we can do through our existing queue.
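As a self-contained illustration of that point, using only the standard library (the `worker` and `consume` names are invented for the example):

```python
import asyncio
import threading


async def consume() -> None:
    loop = asyncio.get_event_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker() -> None:
        try:
            raise RuntimeError("boom")  # stand-in for a failing blocking call
        except Exception as e:
            # An exception raised here would otherwise die with the thread,
            # so hand the exception object itself to the event loop.
            loop.call_soon_threadsafe(queue.put_nowait, e)

    threading.Thread(target=worker, daemon=True).start()

    item = await queue.get()
    if isinstance(item, Exception):
        raise item  # re-raised on the consuming side


try:
    asyncio.run(consume())
except RuntimeError as e:
    print(f"propagated from worker thread: {e}")
```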
```diff
 
             yield event
 
-        await task
-
     def _stream(
         self,
-        callback: Callable[..., None],
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-    ) -> None:
+    ) -> Generator[StreamEvent, None, None]:
         """Stream conversation with the Bedrock model.
 
         This method operates in a separate thread to avoid blocking the async event loop with the call to
         Bedrock's converse_stream.
 
         Args:
-            callback: Function to send events to the main thread.
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
```
```diff
@@ -426,7 +428,6 @@ def _stream(
         streaming = self.config.get("streaming", True)
 
         try:
-            logger.debug("got response from model")
             if streaming:
                 response = self.client.converse_stream(**request)
                 for chunk in response["stream"]:
```
```diff
@@ -438,22 +439,22 @@ def _stream(
                         guardrail_data = chunk["metadata"]["trace"]["guardrail"]
                         if self._has_blocked_guardrail(guardrail_data):
                             for event in self._generate_redaction_events():
-                                callback(event)
+                                yield event
 
-                    callback(chunk)
+                    yield chunk
 
             else:
                 response = self.client.converse(**request)
                 for event in self._convert_non_streaming_to_streaming(response):
-                    callback(event)
+                    yield event
 
                 if (
                     "trace" in response
                     and "guardrail" in response["trace"]
                     and self._has_blocked_guardrail(response["trace"]["guardrail"])
                 ):
                     for event in self._generate_redaction_events():
-                        callback(event)
+                        yield event
 
         except ClientError as e:
             error_message = str(e)
```
```diff
@@ -494,7 +495,6 @@ def _stream(
             raise e
 
         finally:
-            callback()
             logger.debug("finished streaming response from model")
 
     def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
```
**Comment:** nit: should this be private?

**Reply:** The diff is cut off above, but this is a nested function of `stream` and so is private by that nature.
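For reference, the overall pattern this PR lands on — bridging a blocking generator into an async generator via a daemon thread and an `asyncio.Queue` — reduces to roughly the following self-contained sketch. Here `blocking_events` is an invented stand-in for the Bedrock client call; everything else mirrors the diff above:

```python
import asyncio
import threading
import time
from typing import AsyncGenerator, Generator


def blocking_events() -> Generator[int, None, None]:
    # Stand-in for converse_stream: a slow, blocking iterator.
    for i in range(3):
        time.sleep(0.1)
        yield i


async def stream() -> AsyncGenerator[int, None]:
    loop = asyncio.get_event_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def target() -> None:
        try:
            for event in blocking_events():
                # Queue operations must be scheduled onto the loop's thread.
                loop.call_soon_threadsafe(queue.put_nowait, event)
        except Exception as e:
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, None)  # done sentinel

    # Daemon thread: does not block interpreter shutdown.
    threading.Thread(target=target, daemon=True).start()

    while True:
        event = await queue.get()
        if event is None:
            break
        if isinstance(event, Exception):
            raise event
        yield event


async def main() -> None:
    async for event in stream():
        print(event)


asyncio.run(main())
```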