diff --git a/flo_ai/flo_ai/llm/__init__.py b/flo_ai/flo_ai/llm/__init__.py index 934e4da4..1b8d2c42 100644 --- a/flo_ai/flo_ai/llm/__init__.py +++ b/flo_ai/flo_ai/llm/__init__.py @@ -6,6 +6,7 @@ from .openai_vllm import OpenAIVLLM from .vertexai_llm import VertexAI from .rootflo_llm import RootFloLLM +from .aws_bedrock_llm import AWSBedrock __all__ = [ 'BaseLLM', @@ -16,4 +17,5 @@ 'OpenAIVLLM', 'VertexAI', 'RootFloLLM', + 'AWSBedrock', ] diff --git a/flo_ai/flo_ai/llm/aws_bedrock_llm.py b/flo_ai/flo_ai/llm/aws_bedrock_llm.py new file mode 100644 index 00000000..741ec9c3 --- /dev/null +++ b/flo_ai/flo_ai/llm/aws_bedrock_llm.py @@ -0,0 +1,285 @@ +import json +import re +from typing import Dict, Any, List, AsyncIterator, Optional +import boto3 +import asyncio +from .base_llm import BaseLLM +from flo_ai.models.chat_message import ImageMessageContent +from flo_ai.tool.base_tool import Tool +from flo_ai.telemetry.instrumentation import ( + trace_llm_call, + trace_llm_stream, + llm_metrics, + add_span_attributes, +) +from flo_ai.telemetry import get_tracer +from flo_ai.utils.logger import logger +from opentelemetry import trace + + +class AWSBedrock(BaseLLM): # Only openai compatible for now + def __init__( + self, + model: str = 'openai.gpt-oss-20b-1:0', + temperature: float = 0.7, + **kwargs, + ): + super().__init__(model=model, temperature=temperature, **kwargs) + self.boto_client = boto3.client('bedrock-runtime') + self.model = model + self.kwargs = kwargs + + @staticmethod + def _strip_reasoning(text: str) -> str: + return re.sub(r'<reasoning>.*?</reasoning>', '', text, flags=re.DOTALL).strip() + + def _convert_messages( + self, messages: list[dict], output_schema: dict | None = None + ) -> list[dict]: + result = [] + + if output_schema: + result.append( + { + 'role': 'system', + 'content': f'Provide output in the following JSON schema:\n{json.dumps(output_schema, indent=2)}', + } + ) + + for msg in messages: + if msg['role'] == 'function': + result.append( + { + 'role': 'tool',
'tool_call_id': msg.get('tool_use_id', 'unknown'), + 'content': msg['content'], + 'name': msg.get('name', ''), + } + ) + else: + result.append(msg) + + return result + + @trace_llm_call(provider='bedrock') + async def generate( + self, + messages: list[dict], + functions: Optional[List[Dict[str, Any]]] = None, + output_schema: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + converted = self._convert_messages(messages, output_schema) + + request_options = {**self.kwargs, **kwargs} + request_body: Dict[str, Any] = { + 'model': self.model, + 'messages': converted, + 'temperature': self.temperature, + } + if 'max_tokens' in request_options: + request_body['max_completion_tokens'] = request_options['max_tokens'] + if functions: + request_body['tools'] = functions + + response = await asyncio.to_thread( + self.boto_client.invoke_model, + modelId=self.model, + body=json.dumps(request_body), + ) + response_body = json.loads(response['body'].read().decode('utf-8')) + + usage = response_body.get('usage', {}) + if usage: + llm_metrics.record_tokens( + total_tokens=usage.get('total_tokens', 0), + prompt_tokens=usage.get('prompt_tokens', 0), + completion_tokens=usage.get('completion_tokens', 0), + model=self.model, + provider='bedrock', + ) + tracer = get_tracer() + if tracer: + add_span_attributes( + trace.get_current_span(), + { + 'llm.tokens.prompt': usage.get('prompt_tokens', 0), + 'llm.tokens.completion': usage.get('completion_tokens', 0), + 'llm.tokens.total': usage.get('total_tokens', 0), + }, + ) + + choices = response_body.get('choices', []) + if not choices: + return {'content': '', 'raw_message': response_body} + + message = choices[0].get('message', {}) + if 'content' in message and message['content']: + message['content'] = self._strip_reasoning(message['content']) + text_content = message.get('content', '') or '' + tool_call = None + + tool_calls = message.get('tool_calls', []) + if tool_calls: + tc = tool_calls[0] + tool_call = { + 'name': 
tc['function']['name'], + 'arguments': tc['function']['arguments'], + 'id': tc['id'], + } + + if tool_call: + return { + 'content': text_content, + 'function_call': tool_call, + 'raw_message': message, + } + return {'content': text_content, 'raw_message': message} + + @trace_llm_stream(provider='bedrock') + async def stream( + self, + messages: List[Dict[str, Any]], + functions: Optional[List[Dict[str, Any]]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + converted = self._convert_messages(messages) + + request_options = {**self.kwargs, **kwargs} + request_body: Dict[str, Any] = { + 'model': self.model, + 'messages': converted, + 'temperature': self.temperature, + 'stream': True, + } + if 'max_tokens' in request_options: + request_body['max_completion_tokens'] = request_options['max_tokens'] + if functions: + request_body['tools'] = functions + + response = await asyncio.to_thread( + self.boto_client.invoke_model_with_response_stream, + modelId=self.model, + body=json.dumps(request_body), + ) + + queue: asyncio.Queue = asyncio.Queue() + loop = asyncio.get_running_loop() + + def _iter_events(): + try: + for event in response['body']: + chunk_bytes = event.get('chunk', {}).get('bytes', b'') + if chunk_bytes: + loop.call_soon_threadsafe(queue.put_nowait, chunk_bytes) + except Exception as exc: + loop.call_soon_threadsafe(queue.put_nowait, exc) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel + + loop.run_in_executor(None, _iter_events) + + while True: + chunk_bytes = await queue.get() + if isinstance(chunk_bytes, Exception): + raise chunk_bytes + if chunk_bytes is None: + break + text = chunk_bytes.decode('utf-8').strip() + content = None + try: + data = json.loads(text) + content = data.get('choices', [{}])[0].get('delta', {}).get('content') + except json.JSONDecodeError: + # Not valid JSON, try SSE format below + for line in text.split('\n'): + line = line.strip() + if line.startswith('data: ') and line != 'data: [DONE]': + 
try: + data = json.loads(line[6:]) + content = ( + data.get('choices', [{}])[0] + .get('delta', {}) + .get('content') + ) + if content: + clean = self._strip_reasoning(content) + if clean: + yield {'content': clean} + content = None + except json.JSONDecodeError: + logger.debug('Skipping malformed SSE line: %s', line) + if content: + clean = self._strip_reasoning(content) + if clean: + yield {'content': clean} + + def get_message_content(self, response: Dict[str, Any]) -> str: + content = ( + response.get('content', '') if isinstance(response, dict) else str(response) + ) + return self._strip_reasoning(content) + + def format_tool_for_llm(self, tool: 'Tool') -> Dict[str, Any]: + return { + 'type': 'function', + 'function': { + 'name': tool.name, + 'description': tool.description, + 'parameters': { + 'type': 'object', + 'properties': { + name: { + 'type': info.get('type', 'string'), + 'description': info.get('description', ''), + **( + {'items': info['items']} + if info.get('type') == 'array' and 'items' in info + else {} + ), + } + for name, info in tool.parameters.items() + }, + 'required': [ + name + for name, info in tool.parameters.items() + if info.get('required', True) + ], + }, + }, + } + + def format_tools_for_llm(self, tools: List['Tool']) -> List[Dict[str, Any]]: + return [self.format_tool_for_llm(tool) for tool in tools] + + def format_image_in_message(self, image: ImageMessageContent) -> dict: + if image.base64: + return { + 'type': 'image_url', + 'image_url': { + 'url': f'data:{image.mime_type or "image/jpeg"};base64,{image.base64}' + }, + } + raise NotImplementedError( + 'AWS Bedrock LLM requires image base64 data to format image content.' 
+ ) + + def get_assistant_message_for_tool_call( + self, response: Dict[str, Any] + ) -> Optional[Any]: + if isinstance(response, dict) and 'raw_message' in response: + return response['raw_message'] + return None + + def get_tool_use_id(self, function_call: Dict[str, Any]) -> Optional[str]: + return function_call.get('id') + + def format_function_result_message( + self, function_name: str, content: str, tool_use_id: Optional[str] = None + ) -> Dict[str, Any]: + return { + 'role': 'tool', + 'tool_call_id': tool_use_id or 'unknown', + 'content': content, + 'name': function_name, + } diff --git a/flo_ai/pyproject.toml b/flo_ai/pyproject.toml index 2b976988..7e2ff661 100644 --- a/flo_ai/pyproject.toml +++ b/flo_ai/pyproject.toml @@ -10,6 +10,7 @@ license = "MIT" dependencies = [ "aiohttp>=3.12.14,<4", "anthropic>=0.57.1,<0.58", + "boto3>=1.36.1,<2", "chardet>=3.dev0,<4.dev0", "cryptography>=46.0.3", "google-cloud-aiplatform>=1.109.0,<2", @@ -35,8 +36,6 @@ vizualize = [ [dependency-groups] dev = [ - "boto3>=1.36.1,<2", - "botocore>=1.36.1,<2", "db-sqlite3>=0.0.1,<0.0.2", "ipykernel>=6.29.5,<7", "peewee>=3.17.6,<4", diff --git a/flo_ai/uv.lock b/flo_ai/uv.lock index f9b039b9..36b047ce 100644 --- a/flo_ai/uv.lock +++ b/flo_ai/uv.lock @@ -902,6 +902,7 @@ source = { editable = "." 
} dependencies = [ { name = "aiohttp" }, { name = "anthropic" }, + { name = "boto3" }, { name = "chardet" }, { name = "cryptography" }, { name = "google-cloud-aiplatform" }, @@ -928,8 +929,6 @@ vizualize = [ [package.dev-dependencies] dev = [ - { name = "boto3" }, - { name = "botocore" }, { name = "db-sqlite3" }, { name = "ipykernel" }, { name = "peewee" }, @@ -947,6 +946,7 @@ dev = [ requires-dist = [ { name = "aiohttp", specifier = ">=3.12.14,<4" }, { name = "anthropic", specifier = ">=0.57.1,<0.58" }, + { name = "boto3", specifier = ">=1.36.1,<2" }, { name = "chardet", specifier = ">=3.dev0,<4.dev0" }, { name = "cryptography", specifier = ">=46.0.3" }, { name = "google-cloud-aiplatform", specifier = ">=1.109.0,<2" }, @@ -969,8 +969,6 @@ provides-extras = ["vizualize"] [package.metadata.requires-dev] dev = [ - { name = "boto3", specifier = ">=1.36.1,<2" }, - { name = "botocore", specifier = ">=1.36.1,<2" }, { name = "db-sqlite3", specifier = ">=0.0.1,<0.0.2" }, { name = "ipykernel", specifier = ">=6.29.5,<7" }, { name = "peewee", specifier = ">=3.17.6,<4" },