From dc4b818979b526b0d925c871257816caff1ab642 Mon Sep 17 00:00:00 2001 From: x42en Date: Wed, 21 May 2025 19:09:59 +0200 Subject: [PATCH 1/2] feat: Add support for Ollama API --- python_a2a/__init__.py | 253 +++++----- python_a2a/client/__init__.py | 19 +- python_a2a/client/llm/__init__.py | 10 +- python_a2a/client/llm/ollama.py | 86 ++++ python_a2a/client/llm/openai.py | 346 +++++++------ python_a2a/client/streaming.py | 793 ++++++++++++++++-------------- python_a2a/server/__init__.py | 14 +- python_a2a/server/llm/__init__.py | 8 +- python_a2a/server/llm/ollama.py | 105 ++++ python_a2a/server/llm/openai.py | 519 ++++++++++--------- 10 files changed, 1254 insertions(+), 899 deletions(-) create mode 100644 python_a2a/client/llm/ollama.py create mode 100644 python_a2a/server/llm/ollama.py diff --git a/python_a2a/__init__.py b/python_a2a/__init__.py index 4924650..3e41a74 100644 --- a/python_a2a/__init__.py +++ b/python_a2a/__init__.py @@ -21,7 +21,7 @@ A2AValidationError, A2AAuthenticationError, A2AConfigurationError, - A2AStreamingError + A2AStreamingError, ) # All core models - these should be available with basic install @@ -35,7 +35,7 @@ FunctionCallContent, FunctionResponseContent, ErrorContent, - Metadata + Metadata, ) from .models.agent import AgentCard, AgentSkill from .models.task import Task, TaskStatus, TaskState @@ -58,7 +58,7 @@ run_registry, DiscoveryClient, enable_discovery, - RegistryAgent + RegistryAgent, ) # Utility functions @@ -66,13 +66,13 @@ format_message_as_text, format_conversation_as_text, pretty_print_message, - pretty_print_conversation + pretty_print_conversation, ) from .utils.validation import ( validate_message, validate_conversation, is_valid_message, - is_valid_conversation + is_valid_conversation, ) from .utils.conversion import ( create_text_message, @@ -80,7 +80,7 @@ create_function_response, create_error_message, format_function_params, - conversation_to_messages + conversation_to_messages, ) from .utils.decorators import skill, agent @@ -96,7 +96,7 @@ ConditionStep, ParallelStep, ParallelBuilder, - StepType + StepType, ) # MCP integration @@ -106,7 +106,7 @@ MCPConnectionError, MCPTimeoutError, MCPToolError, - MCPTools + MCPTools, ) from .mcp.agent import MCPEnabledAgent from .mcp.fastmcp import ( @@ -116,12 +116,9 @@ error_response, image_response, multi_content_response, - ContentType as MCPContentType -) -from .mcp.integration import ( - FastMCPAgent, - A2AMCPAgent + ContentType as MCPContentType, ) +from .mcp.integration import FastMCPAgent, A2AMCPAgent from .mcp.proxy import create_proxy_server from .mcp.transport import create_fastapi_app @@ -130,7 +127,7 @@ to_a2a_server, to_langchain_agent, to_mcp_server, - to_langchain_tool + to_langchain_tool, ) from .langchain.exceptions import ( LangChainIntegrationError, @@ -138,15 +135,25 @@ LangChainToolConversionError, MCPToolConversionError, LangChainAgentConversionError, - A2AAgentConversionError + A2AAgentConversionError, +) + +HAS_LANGCHAIN = ( + importlib.util.find_spec("langchain") is not None + or importlib.util.find_spec("langchain_core") is not None ) -HAS_LANGCHAIN = importlib.util.find_spec("langchain") is not None or importlib.util.find_spec("langchain_core") is not None # Optional integration with LLM providers # These might not be available if the specific provider packages are not installed try: - from .client.llm import OpenAIA2AClient, AnthropicA2AClient - from .server.llm import OpenAIA2AServer, AnthropicA2AServer, BedrockA2AServer + from .client.llm import OpenAIA2AClient, OllamaA2AClient, AnthropicA2AClient + from .server.llm import ( + OpenAIA2AServer, + OllamaA2AServer, + AnthropicA2AServer, + BedrockA2AServer, + ) + HAS_LLM_CLIENTS = True HAS_LLM_SERVERS = True except ImportError: @@ -156,6 +163,7 @@ # Optional doc generation try: from .docs import generate_a2a_docs, generate_html_docs + HAS_DOCS = True except ImportError: HAS_DOCS = False @@ -163,6 +171,7 @@ # Optional CLI try: from .cli import main as cli_main + HAS_CLI = True except ImportError: HAS_CLI = False @@ -170,6 +179,7 @@ # Agent Flow import - optional but integrated by default try: from .agent_flow import models, engine, server, storage + HAS_AGENT_FLOW_IMPORT = True except ImportError: HAS_AGENT_FLOW_IMPORT = False @@ -193,130 +203,123 @@ # Define __all__ for explicit exports __all__ = [ # Version - '__version__', - + "__version__", # Exceptions - 'A2AError', - 'A2AImportError', - 'A2AConnectionError', - 'A2AResponseError', - 'A2ARequestError', - 'A2AValidationError', - 'A2AAuthenticationError', - 'A2AConfigurationError', - 'A2AStreamingError', - + "A2AError", + "A2AImportError", + "A2AConnectionError", + "A2AResponseError", + "A2ARequestError", + "A2AValidationError", + "A2AAuthenticationError", + "A2AConfigurationError", + "A2AStreamingError", # Models - 'BaseModel', - 'Message', - 'MessageRole', - 'Conversation', - 'ContentType', - 'TextContent', - 'FunctionParameter', - 'FunctionCallContent', - 'FunctionResponseContent', - 'ErrorContent', - 'Metadata', - 'AgentCard', - 'AgentSkill', - 'Task', - 'TaskStatus', - 'TaskState', - + "BaseModel", + "Message", + "MessageRole", + "Conversation", + "ContentType", + "TextContent", + "FunctionParameter", + "FunctionCallContent", + "FunctionResponseContent", + "ErrorContent", + "Metadata", + "AgentCard", + "AgentSkill", + "Task", + "TaskStatus", + "TaskState", # Client - 'BaseA2AClient', - 'A2AClient', - 'AgentNetwork', - 'AIAgentRouter', - 'StreamingClient', - + "BaseA2AClient", + "A2AClient", + "AgentNetwork", + "AIAgentRouter", + "StreamingClient", # Server - 'BaseA2AServer', - 'A2AServer', - 'run_server', - + "BaseA2AServer", + "A2AServer", + "run_server", # Discovery - 'AgentRegistry', - 'run_registry', - 'DiscoveryClient', - 'enable_discovery', - 'RegistryAgent', - + "AgentRegistry", + "run_registry", + "DiscoveryClient", + "enable_discovery", + "RegistryAgent", # Utilities - 'format_message_as_text', - 'format_conversation_as_text', - 'pretty_print_message', - 'pretty_print_conversation', - 'validate_message', - 'validate_conversation', - 'is_valid_message', - 'is_valid_conversation', - 'create_text_message', - 'create_function_call', - 'create_function_response', - 'create_error_message', - 'format_function_params', - 'conversation_to_messages', - 'skill', - 'agent', - + "format_message_as_text", + "format_conversation_as_text", + "pretty_print_message", + "pretty_print_conversation", + "validate_message", + "validate_conversation", + "is_valid_message", + "is_valid_conversation", + "create_text_message", + "create_function_call", + "create_function_response", + "create_error_message", + "format_function_params", + "conversation_to_messages", + "skill", + "agent", # Workflow - 'Flow', - 'WorkflowContext', - 'WorkflowStep', - 'QueryStep', - 'AutoRouteStep', - 'FunctionStep', - 'ConditionalBranch', - 'ConditionStep', - 'ParallelStep', - 'ParallelBuilder', - 'StepType', - + "Flow", + "WorkflowContext", + "WorkflowStep", + "QueryStep", + "AutoRouteStep", + "FunctionStep", + "ConditionalBranch", + "ConditionStep", + "ParallelStep", + "ParallelBuilder", + "StepType", # MCP - 'MCPClient', - 'MCPError', - 'MCPConnectionError', - 'MCPTimeoutError', - 'MCPToolError', - 'MCPTools', - 'MCPEnabledAgent', - 'FastMCP', - 'MCPResponse', - 'text_response', - 'error_response', - 'image_response', - 'multi_content_response', - 'MCPContentType', - 'FastMCPAgent', - 'A2AMCPAgent', - 'create_proxy_server', - 'create_fastapi_app', - + "MCPClient", + "MCPError", + "MCPConnectionError", + "MCPTimeoutError", + "MCPToolError", + "MCPTools", + "MCPEnabledAgent", + "FastMCP", + "MCPResponse", + "text_response", + "error_response", + "image_response", + "multi_content_response", + "MCPContentType", + "FastMCPAgent", + "A2AMCPAgent", + "create_proxy_server", + "create_fastapi_app", # LangChain Integration (always included) - 'to_a2a_server', - 'to_langchain_agent', - 'to_mcp_server', - 'to_langchain_tool', - 'LangChainIntegrationError', - 'LangChainNotInstalledError', - 'LangChainToolConversionError', - 'MCPToolConversionError', - 'LangChainAgentConversionError', - 'A2AAgentConversionError', + "to_a2a_server", + "to_langchain_agent", + "to_mcp_server", + "to_langchain_tool", + "LangChainIntegrationError", + "LangChainNotInstalledError", + "LangChainToolConversionError", + "MCPToolConversionError", + "LangChainAgentConversionError", + "A2AAgentConversionError", ] # Conditionally add LLM clients/servers if HAS_LLM_CLIENTS: - __all__.extend(['OpenAIA2AClient', 'AnthropicA2AClient']) + __all__.extend(["OpenAIA2AClient", "OllamaA2AClient", "AnthropicA2AClient"]) if HAS_LLM_SERVERS: - __all__.extend(['OpenAIA2AServer', 'AnthropicA2AServer', 'BedrockA2AServer']) + __all__.extend( + ["OpenAIA2AServer", "OllamaA2AServer", "AnthropicA2AServer", "BedrockA2AServer"] + ) # Conditionally add docs if HAS_DOCS: - __all__.extend(['generate_a2a_docs', 'generate_html_docs']) + __all__.extend(["generate_a2a_docs", "generate_html_docs"]) # Conditionally add CLI if HAS_CLI: - __all__.append('cli_main') \ No newline at end of file + __all__.append("cli_main") diff --git a/python_a2a/client/__init__.py b/python_a2a/client/__init__.py index 1cccc34..35df225 100644 --- a/python_a2a/client/__init__.py +++ b/python_a2a/client/__init__.py @@ -7,7 +7,7 @@ from .http import A2AClient # Import LLM-specific clients -from .llm import OpenAIA2AClient, AnthropicA2AClient +from .llm import OpenAIA2AClient, OllamaA2AClient, AnthropicA2AClient # Import enhanced components from .network import AgentNetwork @@ -16,11 +16,12 @@ # Make everything available at the client level __all__ = [ - 'BaseA2AClient', - 'A2AClient', - 'OpenAIA2AClient', - 'AnthropicA2AClient', - 'AgentNetwork', - 'AIAgentRouter', - 'StreamingClient' -] \ No newline at end of file + "BaseA2AClient", + "A2AClient", + "OpenAIA2AClient", + "OllamaA2AClient", + "AnthropicA2AClient", + "AgentNetwork", + "AIAgentRouter", + "StreamingClient", +] diff --git a/python_a2a/client/llm/__init__.py b/python_a2a/client/llm/__init__.py index fc14b85..f50ada6 100644 --- a/python_a2a/client/llm/__init__.py +++ b/python_a2a/client/llm/__init__.py @@ -4,12 +4,14 @@ # Import and re-export LLM clients from .openai import OpenAIA2AClient +from .ollama import OllamaA2AClient from .anthropic import AnthropicA2AClient from .bedrock import BedrockA2AClient # Make all clients available at the llm level __all__ = [ - 'OpenAIA2AClient', - 'AnthropicA2AClient', - 'BedrockA2AClient' -] \ No newline at end of file + "OpenAIA2AClient", + "OllamaA2AClient", + "AnthropicA2AClient", + "BedrockA2AClient", +] diff --git a/python_a2a/client/llm/ollama.py b/python_a2a/client/llm/ollama.py new file mode 100644 index 0000000..b6c9b5c --- /dev/null +++ b/python_a2a/client/llm/ollama.py @@ -0,0 +1,86 @@ +""" +Ollama-based client implementation for the A2A protocol. +""" + +import requests +from typing import Optional, List, Dict, Any + +try: + from openai import OpenAI +except ImportError: + OpenAI = None + +from .openai import OpenAIA2AClient +from ...exceptions import A2AImportError, A2AConnectionError + + +class OllamaA2AClient(OpenAIA2AClient): + """A2A client that uses OpenAI's API on Ollama server to process messages.""" + + def __init__( + self, + api_url: str, + model: str, + temperature: float = 0.7, + system_prompt: Optional[str] = None, + functions: Optional[List[Dict[str, Any]]] = None, + ): + """ + Initialize the Ollama A2A client + + Args: + api_url: Ollama API URL + model: Ollama model to use + temperature: Generation temperature (default: 0.7) + system_prompt: Optional system prompt for all conversations + functions: Optional list of function definitions for function calling + + Raises: + A2AImportError: If the ollama package is not installed + """ + super().__init__( + model=model, + api_key=None, + temperature=temperature, + system_prompt=system_prompt, + functions=functions, + ) + + # Initialize OpenAI compatible client + self.__api_url = api_url + + try: + self.__models = self.list_models() + except Exception as err: + raise A2AImportError( + f"Ollama API is not available. Please check your installation. {err}" + ) + + if model not in self.__models: + raise A2AImportError(f"Model '{model}' is not available in the Ollama API.") + + self.client = OpenAI(base_url=f"{api_url}/v1", api_key="ollama") + + def list_models(self) -> List[str]: + """ + List available models from the Ollama API. + + Returns: + List of model names. + """ + try: + result = requests.get(f"{self.__api_url}/api/tags") + jsondata = result.json() + return [m.get("model") for m in jsondata.get("models") if m.get("model")] + except requests.RequestException as err: + raise A2AConnectionError( + f"Failed to connect to Ollama API at {self.__api_url}" + ) from err + except ValueError as err: + raise ValueError( + f"Failed to parse response from Ollama API at {self.__api_url}" + ) from err + except Exception as err: + raise A2AConnectionError( + f"An unexpected error occurred while connecting to Ollama API at {self.__api_url}" + ) from err diff --git a/python_a2a/client/llm/openai.py b/python_a2a/client/llm/openai.py index 4573950..644fd91 100644 --- a/python_a2a/client/llm/openai.py +++ b/python_a2a/client/llm/openai.py @@ -11,7 +11,12 @@ OpenAI = None from ...models.message import Message, MessageRole -from ...models.content import TextContent, FunctionCallContent, FunctionResponseContent, FunctionParameter +from ...models.content import ( + TextContent, + FunctionCallContent, + FunctionResponseContent, + FunctionParameter, +) from ...models.conversation import Conversation from ..base import BaseA2AClient from ...exceptions import A2AImportError, A2AConnectionError @@ -19,25 +24,25 @@ class OpenAIA2AClient(BaseA2AClient): """A2A client that uses OpenAI's API to process messages.""" - + def __init__( self, api_key: str, model: str = "gpt-3.5-turbo", temperature: float = 0.7, system_prompt: Optional[str] = None, - functions: Optional[List[Dict[str, Any]]] = None + functions: Optional[List[Dict[str, Any]]] = None, ): """ Initialize the OpenAI A2A client - + Args: api_key: OpenAI API key model: OpenAI model to use (default: "gpt-3.5-turbo") temperature: Generation temperature (default: 0.7) system_prompt: Optional system prompt for all conversations functions: Optional list of function definitions for function calling - + Raises: A2AImportError: If the openai package is not installed """ @@ -46,93 +51,100 @@ def __init__( "OpenAI package is not installed. " "Install it with 'pip install openai'" ) - + self.api_key = api_key self.model = model self.temperature = temperature self.system_prompt = system_prompt or "You are a helpful assistant." self.functions = functions self.tools = self._convert_functions_to_tools() if functions else None - - # Initialize OpenAI client - self.client = OpenAI(api_key=api_key) - + + # Initialize OpenAI client only if the API key is provided + if api_key: + try: + self.client = OpenAI(api_key=api_key) + except Exception as err: + raise A2AConnectionError(f"Failed to connect to OpenAI API: {str(err)}") + # Store message history for conversations self._conversation_histories = {} - + def _convert_functions_to_tools(self): """Convert functions to the tools format used by newer OpenAI models""" if not self.functions: return None - + tools = [] for func in self.functions: - tools.append({ - "type": "function", - "function": func - }) + tools.append({"type": "function", "function": func}) return tools - + def send_message(self, message: Message) -> Message: """ Send a message to OpenAI's API and return the response as an A2A message - + Args: message: The A2A message to send - + Returns: The response as an A2A message - + Raises: A2AConnectionError: If connection to OpenAI fails """ try: # Create OpenAI message format openai_messages = [{"role": "system", "content": self.system_prompt}] - + # If this is part of a conversation, retrieve history conversation_id = message.conversation_id if conversation_id and conversation_id in self._conversation_histories: openai_messages = self._conversation_histories[conversation_id].copy() - + # Add the current message if message.content.type == "text": - openai_messages.append({ - "role": "user" if message.role == MessageRole.USER else "assistant", - "content": message.content.text - }) + openai_messages.append( + { + "role": ( + "user" if message.role == MessageRole.USER else "assistant" + ), + "content": message.content.text, + } + ) elif message.content.type == "function_call": # Convert function call to string representation - params_str = ", ".join([f"{p.name}={p.value}" for p in message.content.parameters]) + params_str = ", ".join( + [f"{p.name}={p.value}" for p in message.content.parameters] + ) text = f"Call function {message.content.name}({params_str})" openai_messages.append({"role": "user", "content": text}) elif message.content.type == "function_response": # Format function response in OpenAI's expected format - openai_messages.append({ - "role": "function", - "name": message.content.name, - "content": json.dumps(message.content.response) - }) + openai_messages.append( + { + "role": "function", + "name": message.content.name, + "content": json.dumps(message.content.response), + } + ) elif message.content.type == "error": # Convert error to text - openai_messages.append({ - "role": "user", - "content": f"Error: {message.content.message}" - }) + openai_messages.append( + {"role": "user", "content": f"Error: {message.content.message}"} + ) else: # Default case for unknown content types - openai_messages.append({ - "role": "user", - "content": str(message.content) - }) - + openai_messages.append( + {"role": "user", "content": str(message.content)} + ) + # Prepare API call parameters kwargs = { "model": self.model, "messages": openai_messages, "temperature": self.temperature, } - + # Add functions or tools if provided if self.tools: # Newer models use tools @@ -142,14 +154,14 @@ def send_message(self, message: Message) -> Message: # Older models use functions kwargs["functions"] = self.functions kwargs["function_call"] = "auto" - + # Call OpenAI API response = self.client.chat.completions.create(**kwargs) - + # Parse response choice = response.choices[0] response_message = choice.message - + # Update conversation history if we have a conversation ID if conversation_id: if conversation_id not in self._conversation_histories: @@ -157,41 +169,56 @@ def send_message(self, message: Message) -> Message: self._conversation_histories[conversation_id] = [ {"role": "system", "content": self.system_prompt} ] - + # Add the user message to history if message.content.type == "text": - self._conversation_histories[conversation_id].append({ - "role": "user" if message.role == MessageRole.USER else "assistant", - "content": message.content.text - }) + self._conversation_histories[conversation_id].append( + { + "role": ( + "user" + if message.role == MessageRole.USER + else "assistant" + ), + "content": message.content.text, + } + ) elif message.content.type == "function_response": - self._conversation_histories[conversation_id].append({ - "role": "function", - "name": message.content.name, - "content": json.dumps(message.content.response) - }) - + self._conversation_histories[conversation_id].append( + { + "role": "function", + "name": message.content.name, + "content": json.dumps(message.content.response), + } + ) + # Add the assistant's response to history if hasattr(response_message, "content") and response_message.content: - self._conversation_histories[conversation_id].append({ - "role": "assistant", - "content": response_message.content - }) - elif hasattr(response_message, "function_call") and response_message.function_call: + self._conversation_histories[conversation_id].append( + {"role": "assistant", "content": response_message.content} + ) + elif ( + hasattr(response_message, "function_call") + and response_message.function_call + ): # Add function call to history - self._conversation_histories[conversation_id].append({ - "role": "assistant", - "function_call": { - "name": response_message.function_call.name, - "arguments": response_message.function_call.arguments + self._conversation_histories[conversation_id].append( + { + "role": "assistant", + "function_call": { + "name": response_message.function_call.name, + "arguments": response_message.function_call.arguments, + }, } - }) - + ) + # Convert to A2A message format - if hasattr(response_message, "function_call") and response_message.function_call: + if ( + hasattr(response_message, "function_call") + and response_message.function_call + ): # Handle function call response function_call = response_message.function_call - + try: # Parse arguments as JSON args = json.loads(function_call.arguments) @@ -202,20 +229,21 @@ def send_message(self, message: Message) -> Message: except: # Fallback for non-JSON arguments parameters = [ - FunctionParameter(name="arguments", value=function_call.arguments) + FunctionParameter( + name="arguments", value=function_call.arguments + ) ] - + # Create function call message return Message( content=FunctionCallContent( - name=function_call.name, - parameters=parameters + name=function_call.name, parameters=parameters ), role=MessageRole.AGENT, parent_message_id=message.message_id, - conversation_id=message.conversation_id + conversation_id=message.conversation_id, ) - + # Handle tool calls in newer models tool_calls = getattr(response_message, "tool_calls", None) if tool_calls: @@ -231,94 +259,103 @@ def send_message(self, message: Message) -> Message: except: # Fallback for non-JSON arguments parameters = [ - FunctionParameter(name="arguments", value=tool_call.function.arguments) + FunctionParameter( + name="arguments", value=tool_call.function.arguments + ) ] - + # Create function call message return Message( content=FunctionCallContent( - name=tool_call.function.name, - parameters=parameters + name=tool_call.function.name, parameters=parameters ), role=MessageRole.AGENT, parent_message_id=message.message_id, - conversation_id=message.conversation_id + conversation_id=message.conversation_id, ) - + # Regular text response return Message( content=TextContent(text=response_message.content or ""), role=MessageRole.AGENT, parent_message_id=message.message_id, - conversation_id=message.conversation_id + conversation_id=message.conversation_id, ) - + except Exception as e: # Create error message return Message( content=TextContent(text=f"Error from OpenAI API: {str(e)}"), role=MessageRole.AGENT, parent_message_id=message.message_id, - conversation_id=message.conversation_id + conversation_id=message.conversation_id, ) - + def send_conversation(self, conversation: Conversation) -> Conversation: """ Send a full conversation to OpenAI's API and get an updated conversation - + Args: conversation: The A2A conversation to send - + Returns: The updated conversation with the response - + Raises: A2AConnectionError: If connection to OpenAI fails """ if not conversation.messages: # Empty conversation, return as is return conversation - + try: # Initialize OpenAI message format with system prompt openai_messages = [{"role": "system", "content": self.system_prompt}] - + # Add all messages from the conversation for msg in conversation.messages: if msg.content.type == "text": - openai_messages.append({ - "role": "user" if msg.role == MessageRole.USER else "assistant", - "content": msg.content.text - }) + openai_messages.append( + { + "role": ( + "user" if msg.role == MessageRole.USER else "assistant" + ), + "content": msg.content.text, + } + ) elif msg.content.type == "function_call": # Convert function call to string - params_str = ", ".join([f"{p.name}={p.value}" for p in msg.content.parameters]) + params_str = ", ".join( + [f"{p.name}={p.value}" for p in msg.content.parameters] + ) text = f"Call function {msg.content.name}({params_str})" - + role = "user" if msg.role == MessageRole.USER else "assistant" openai_messages.append({"role": role, "content": text}) elif msg.content.type == "function_response": # Format function response for OpenAI - openai_messages.append({ - "role": "function", - "name": msg.content.name, - "content": json.dumps(msg.content.response) - }) - + openai_messages.append( + { + "role": "function", + "name": msg.content.name, + "content": json.dumps(msg.content.response), + } + ) + # Get conversation ID for tracking history conversation_id = conversation.conversation_id - + # Store the conversation in history if conversation_id: self._conversation_histories[conversation_id] = openai_messages.copy() - + # Prepare API call parameters kwargs = { "model": self.model, "messages": openai_messages, "temperature": self.temperature, } - + # Add functions or tools if provided if self.tools: kwargs["tools"] = self.tools @@ -326,28 +363,34 @@ def send_conversation(self, conversation: Conversation) -> Conversation: elif self.functions: kwargs["functions"] = self.functions kwargs["function_call"] = "auto" - + # Call OpenAI API response = self.client.chat.completions.create(**kwargs) - + # Parse response choice = response.choices[0] response_message = choice.message - + # Add to conversation history - if conversation_id and hasattr(response_message, "content") and response_message.content: - self._conversation_histories[conversation_id].append({ - "role": "assistant", - "content": response_message.content - }) - + if ( + conversation_id + and hasattr(response_message, "content") + and response_message.content + ): + self._conversation_histories[conversation_id].append( + {"role": "assistant", "content": response_message.content} + ) + # Get the last message in the conversation as parent last_message = conversation.messages[-1] - + # Create a new message based on the response - if hasattr(response_message, "function_call") and response_message.function_call: + if ( + hasattr(response_message, "function_call") + and response_message.function_call + ): function_call = response_message.function_call - + # Parse arguments as JSON try: args = json.loads(function_call.arguments) @@ -358,22 +401,25 @@ def send_conversation(self, conversation: Conversation) -> Conversation: except: # Fallback for non-JSON arguments parameters = [ - FunctionParameter(name="arguments", value=function_call.arguments) + FunctionParameter( + name="arguments", value=function_call.arguments + ) ] - + # Add function call message to conversation a2a_message = Message( content=FunctionCallContent( - name=function_call.name, - parameters=parameters + name=function_call.name, parameters=parameters ), role=MessageRole.AGENT, parent_message_id=last_message.message_id, - conversation_id=conversation_id + conversation_id=conversation_id, ) conversation.add_message(a2a_message) - - elif hasattr(response_message, "tool_calls") and response_message.tool_calls: + + elif ( + hasattr(response_message, "tool_calls") and response_message.tool_calls + ): # Handle tool calls tool_call = response_message.tool_calls[0] if tool_call.type == "function": @@ -387,18 +433,19 @@ def send_conversation(self, conversation: Conversation) -> Conversation: except: # Fallback for non-JSON arguments parameters = [ - FunctionParameter(name="arguments", value=tool_call.function.arguments) + FunctionParameter( + name="arguments", value=tool_call.function.arguments + ) ] - + # Add function call message to conversation a2a_message = Message( content=FunctionCallContent( - name=tool_call.function.name, - parameters=parameters + name=tool_call.function.name, parameters=parameters ), role=MessageRole.AGENT, parent_message_id=last_message.message_id, - conversation_id=conversation_id + conversation_id=conversation_id, ) conversation.add_message(a2a_message) else: @@ -407,56 +454,57 @@ def send_conversation(self, conversation: Conversation) -> Conversation: content=TextContent(text=response_message.content or ""), role=MessageRole.AGENT, parent_message_id=last_message.message_id, - conversation_id=conversation_id + conversation_id=conversation_id, ) conversation.add_message(a2a_message) - + return conversation - + except Exception as e: # Add error message to conversation error_msg = f"Error from OpenAI API: {str(e)}" - + # Use the last message as parent if available - parent_id = conversation.messages[-1].message_id if conversation.messages else None - + parent_id = ( + conversation.messages[-1].message_id if conversation.messages else None + ) + conversation.create_error_message(error_msg, parent_message_id=parent_id) return conversation - + def ask(self, query: str) -> str: """ Simple helper for text-based queries - + Args: query: Text query to send - + Returns: Text response from the model """ # Create message - message = Message( - content=TextContent(text=query), - role=MessageRole.USER - ) - + message = Message(content=TextContent(text=query), role=MessageRole.USER) + # Send message and get response response = self.send_message(message) - + # Extract text from response if response.content.type == "text": return response.content.text elif response.content.type == "function_call": # Format function call as text - params_str = ", ".join([f"{p.name}={p.value}" for p in response.content.parameters]) + params_str = ", ".join( + [f"{p.name}={p.value}" for p in response.content.parameters] + ) return f"Function call: {response.content.name}({params_str})" else: # Default case for other content types return str(response.content) - + def clear_conversation_history(self, conversation_id: str = None): """ Clear conversation history for a specific conversation or all conversations - + Args: conversation_id: ID of conversation to clear, or None to clear all """ @@ -468,4 +516,4 @@ def clear_conversation_history(self, conversation_id: str = None): ] else: # Clear all conversation histories - self._conversation_histories = {} \ No newline at end of file + self._conversation_histories = {} diff --git a/python_a2a/client/streaming.py b/python_a2a/client/streaming.py index 866fd13..73d3a7d 100644 --- a/python_a2a/client/streaming.py +++ b/python_a2a/client/streaming.py @@ -25,7 +25,7 @@ class StreamingChunk: """ A structured representation of a streaming chunk from an A2A agent. - + Attributes: content: The content of the chunk is_last: Whether this is the last chunk @@ -33,18 +33,18 @@ class StreamingChunk: index: The sequence index of this chunk (if provided) event_type: The type of event (default: 'chunk') """ - + def __init__( - self, + self, content: Any, is_last: bool = False, append: bool = True, index: Optional[int] = None, - event_type: str = 'chunk' + event_type: str = "chunk", ): """ Initialize a streaming chunk. - + Args: content: The content of the chunk is_last: Whether this is the last chunk @@ -57,72 +57,78 @@ def __init__( self.append = append self.index = index self.event_type = event_type - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'StreamingChunk': + def from_dict(cls, data: Dict[str, Any]) -> "StreamingChunk": """ Create a StreamingChunk from a dictionary. - + Args: data: Dictionary representation of a streaming chunk - + Returns: A StreamingChunk object """ - content = data.get('content', '') - + content = data.get("content", "") + # Look for text content in different formats - if isinstance(content, dict) and 'text' in content: - content = content['text'] - elif isinstance(content, dict) and 'parts' in content and isinstance(content['parts'], list): + if isinstance(content, dict) and "text" in content: + content = content["text"] + elif ( + isinstance(content, dict) + and "parts" in content + and isinstance(content["parts"], list) + ): # Extract text from parts array format - for part in content['parts']: - if isinstance(part, dict) and part.get('type') == 'text' and 'text' in part: - content = part['text'] + for part in content["parts"]: + if ( + isinstance(part, dict) + and part.get("type") == "text" + and "text" in part + ): + content = part["text"] break - + return cls( content=content, - is_last=data.get('lastChunk', False), - append=data.get('append', True), - index=data.get('index'), - event_type=data.get('event', 'chunk') + is_last=data.get("lastChunk", False), + append=data.get("append", True), + index=data.get("index"), + event_type=data.get("event", "chunk"), ) class StreamingClient(BaseA2AClient): """ Client for streaming responses from A2A-compatible agents. - + This client enhances the standard A2A client with streaming response capabilities, allowing for real-time processing of agent responses. """ - + def __init__( - self, - url: str, - headers: Optional[Dict[str, str]] = None, - timeout: int = 30 + self, url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 30 ): """ Initialize a streaming client. - + Args: url: Base URL of the A2A agent headers: Optional HTTP headers to include in requests timeout: Request timeout in seconds """ - self.url = url.rstrip('/') + self.url = url.rstrip("/") self.headers = headers or {} self.timeout = timeout - + # Ensure content type is set for JSON - if 'Content-Type' not in self.headers: - self.headers['Content-Type'] = 'application/json' - + if "Content-Type" not in self.headers: + self.headers["Content-Type"] = "application/json" + # Check if SSE support is available try: import aiohttp + self._has_aiohttp = True except ImportError: self._has_aiohttp = False @@ -130,44 +136,43 @@ def __init__( "aiohttp not installed. Streaming will use polling instead. " "Install aiohttp for better streaming support." ) - + # Flag for checking if the agent supports streaming self._supports_streaming = None - + async def check_streaming_support(self) -> bool: """ Check if the agent supports streaming. - + Returns: True if streaming is supported, False otherwise """ if self._supports_streaming is not None: return self._supports_streaming - + # Try to fetch agent metadata to check for streaming capability try: # Check if aiohttp is available if not self._has_aiohttp: self._supports_streaming = False return False - + # Import to avoid circular imports from ..models import AgentCard - + # Try to load agent card async with self._create_session() as session: # Create headers specifically for JSON content negotiation json_headers = {"Accept": "application/json"} - + # First attempt with primary endpoint async with session.get( - f"{self.url}/agent.json", - headers=json_headers + f"{self.url}/agent.json", headers=json_headers ) as response: if response.status == 200: # Check content type to ensure we got JSON - content_type = response.headers.get('Content-Type', '') - if 'application/json' in content_type: + content_type = response.headers.get("Content-Type", "") + if "application/json" in content_type: # Parse JSON directly data = await response.json() else: @@ -176,26 +181,29 @@ async def check_streaming_support(self) -> bool: text = await response.text() data = self._extract_json_from_response(text) except Exception as e: - logger.warning(f"Failed to extract JSON from response: {e}") + logger.warning( + f"Failed to extract JSON from response: {e}" + ) data = {} - + # Check capabilities self._supports_streaming = ( - isinstance(data, dict) and - isinstance(data.get("capabilities"), dict) and - data.get("capabilities", {}).get("streaming", False) + isinstance(data, dict) + and isinstance(data.get("capabilities"), dict) + and data.get("capabilities", {}).get("streaming", False) ) else: # Try alternate endpoint alternate_url = f"{self.url}/a2a/agent.json" async with session.get( - alternate_url, - headers=json_headers + alternate_url, headers=json_headers ) as alt_response: if alt_response.status == 200: # Check content type to ensure we got JSON - content_type = alt_response.headers.get('Content-Type', '') - if 'application/json' in content_type: + content_type = alt_response.headers.get( + "Content-Type", "" + ) + if "application/json" in content_type: # Parse JSON directly data = await alt_response.json() else: @@ -204,24 +212,28 @@ async def check_streaming_support(self) -> bool: text = await alt_response.text() data = self._extract_json_from_response(text) except Exception as e: - logger.warning(f"Failed to extract JSON from response: {e}") + logger.warning( + f"Failed to extract JSON from response: {e}" + ) data = {} - + # Check capabilities self._supports_streaming = ( - isinstance(data, dict) and - isinstance(data.get("capabilities"), dict) and - data.get("capabilities", {}).get("streaming", False) + isinstance(data, dict) + and isinstance(data.get("capabilities"), dict) + and data.get("capabilities", {}).get( + "streaming", False + ) ) else: self._supports_streaming = False - + except Exception as e: logger.warning(f"Error checking streaming support: {e}") self._supports_streaming = False - + return self._supports_streaming - + def _create_session(self): """Create an aiohttp session.""" if not self._has_aiohttp: @@ -229,23 +241,23 @@ def _create_session(self): "aiohttp is required for streaming. " "Install it with 'pip install aiohttp'." ) - + import aiohttp + return aiohttp.ClientSession( - headers=self.headers, - timeout=aiohttp.ClientTimeout(total=self.timeout) + headers=self.headers, timeout=aiohttp.ClientTimeout(total=self.timeout) ) - + def send_message(self, message: Message) -> Message: """ Send a message to an A2A-compatible agent (synchronous). - + This method overrides the BaseA2AClient.send_message method to provide backward compatibility. - + Args: message: The A2A message to send - + Returns: The agent's response as an A2A message """ @@ -256,19 +268,19 @@ def send_message(self, message: Message) -> Message: # No event loop in this thread, create one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + return loop.run_until_complete(self.send_message_async(message)) - + def send_conversation(self, conversation: Conversation) -> Conversation: """ Send a conversation to an A2A-compatible agent (synchronous). - + This method overrides the BaseA2AClient.send_conversation method to provide backward compatibility. - + Args: conversation: The conversation to send - + Returns: The updated conversation with the agent's response """ @@ -279,44 +291,44 @@ def send_conversation(self, conversation: Conversation) -> Conversation: # No event loop in this thread, create one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + return loop.run_until_complete(self.send_conversation_async(conversation)) - + async def send_conversation_async(self, conversation: Conversation) -> Conversation: """ Send a conversation to an A2A-compatible agent (asynchronous). - + Args: conversation: The conversation to send - + Returns: The updated conversation with the agent's response """ # For simplicity, extract the last message and send it if not conversation.messages: raise ValueError("Cannot send an empty conversation") - + # Get last message (typically from the user) last_message = conversation.messages[-1] - + # Send the message response = await self.send_message_async(last_message) - + # Add the response to the conversation conversation.add_message(response) - + return conversation - + async def send_message_async(self, message: Message) -> Message: """ Send a message to an A2A-compatible agent (asynchronous). - + Args: message: The A2A message to send - + Returns: The agent's response as an A2A message - + Raises: A2AConnectionError: If connection to the agent fails A2AResponseError: If the agent returns an invalid response @@ -325,113 +337,109 @@ async def send_message_async(self, message: Message) -> Message: if not self._has_aiohttp: # Fall back to synchronous requests if aiohttp not available import requests + response = requests.post( self.url, json=message.to_dict(), headers=self.headers, - timeout=self.timeout + timeout=self.timeout, ) response.raise_for_status() return Message.from_dict(response.json()) - + # Asynchronous request with aiohttp async with self._create_session() as session: - async with session.post( - self.url, - json=message.to_dict() - ) as response: + async with session.post(self.url, json=message.to_dict()) as response: # Handle HTTP errors if response.status >= 400: error_text = await response.text() raise A2AConnectionError( f"HTTP error {response.status}: {error_text}" ) - + # Parse the response try: data = await response.json() return Message.from_dict(data) except ValueError as e: raise A2AResponseError(f"Invalid response from agent: {str(e)}") - + except Exception as e: if isinstance(e, (A2AConnectionError, A2AResponseError)): raise - + # Create an error message as response return Message( content=TextContent(text=f"Error: {str(e)}"), role=MessageRole.SYSTEM, parent_message_id=message.message_id, - conversation_id=message.conversation_id + conversation_id=message.conversation_id, ) - + async def stream_response( - self, + self, message: Message, - chunk_callback: Optional[Callable[[Union[str, Dict]], None]] = None + chunk_callback: Optional[Callable[[Union[str, Dict]], None]] = None, ) -> AsyncGenerator[Union[str, Dict], None]: """ Stream a response from an A2A-compatible agent. - + Args: message: The A2A message to send chunk_callback: Optional callback function for each chunk - + Yields: Response chunks from the agent - + Raises: A2AConnectionError: If connection to the agent fails A2AResponseError: If the agent returns an invalid response """ # Check if streaming is supported supports_streaming = await self.check_streaming_support() - + if not supports_streaming: # Fall back to non-streaming if not supported response = await self.send_message_async(message) - + # Get text from response if hasattr(response.content, "text"): result = response.content.text else: result = str(response.content) - + # Yield the entire response as one chunk if chunk_callback: chunk_callback(result) yield result return - + if not self._has_aiohttp: # Fall back to non-streaming if aiohttp not available response = await self.send_message_async(message) - + # Get text from response if hasattr(response.content, "text"): result = response.content.text else: result = str(response.content) - + # Yield the entire response as one chunk if chunk_callback: chunk_callback(result) yield result return - + # Real streaming implementation with aiohttp try: # Set up streaming request async with self._create_session() as session: headers = dict(self.headers) # Add headers to request server-sent events - headers['Accept'] = 'text/event-stream' - + headers["Accept"] = "text/event-stream" + async with session.post( - f"{self.url}/stream", - json=message.to_dict(), - headers=headers + f"{self.url}/stream", json=message.to_dict(), headers=headers ) as response: # Handle HTTP errors if response.status >= 400: @@ -439,67 +447,65 @@ async def stream_response( raise A2AConnectionError( f"HTTP error {response.status}: {error_text}" ) - + # Process the streaming response async for chunk in self._process_stream(response, chunk_callback): yield chunk - + except Exception as e: if isinstance(e, (A2AConnectionError, A2AResponseError)): raise - + # Fall back to non-streaming for other errors logger.warning(f"Error in streaming, falling back to non-streaming: {e}") response = await self.send_message_async(message) - + # Get text from response if hasattr(response.content, "text"): result = response.content.text else: result = str(response.content) - + # Yield the entire response as one chunk if chunk_callback: chunk_callback(result) yield result - + async def create_task(self, message: Union[str, Message]) -> Task: """ Create a task from a message. - + Args: message: Text message or Message object - + Returns: The created task """ # Convert to Message if needed if isinstance(message, str): message_obj = Message( - content=TextContent(text=message), - role=MessageRole.USER + content=TextContent(text=message), role=MessageRole.USER ) else: message_obj = message - + # Create a task task = Task( - id=str(id(message_obj)), # Simple unique ID - message=message_obj.to_dict() + id=str(id(message_obj)), message=message_obj.to_dict() # Simple unique ID ) - + return task - + async def send_task(self, task: Task) -> Task: """ Send a task to an A2A-compatible agent. - + Args: task: The task to send - + Returns: The updated task with the agent's response - + Raises: A2AConnectionError: If connection to the agent fails A2AResponseError: If the agent returns an invalid response @@ -508,7 +514,7 @@ async def send_task(self, task: Task) -> Task: if not self._has_aiohttp: # Fall back to synchronous requests if aiohttp not available import requests - + # Try POST to /tasks/send endpoint try: response = requests.post( @@ -517,10 +523,10 @@ async def send_task(self, task: Task) -> Task: "jsonrpc": "2.0", "id": 1, "method": "tasks/send", - "params": task.to_dict() + "params": task.to_dict(), }, headers=self.headers, - timeout=self.timeout + timeout=self.timeout, ) response.raise_for_status() result = response.json().get("result", {}) @@ -533,15 +539,15 @@ async def send_task(self, task: Task) -> Task: "jsonrpc": "2.0", "id": 1, "method": "tasks/send", - "params": task.to_dict() + "params": task.to_dict(), }, headers=self.headers, - timeout=self.timeout + timeout=self.timeout, ) response.raise_for_status() result = response.json().get("result", {}) return Task.from_dict(result) - + # Asynchronous request with aiohttp async with self._create_session() as session: # Try POST to /tasks/send endpoint @@ -552,19 +558,19 @@ async def send_task(self, task: Task) -> Task: "jsonrpc": "2.0", "id": 1, "method": "tasks/send", - "params": task.to_dict() - } + "params": task.to_dict(), + }, ) as response: # Handle HTTP errors if response.status >= 400: # Try alternate endpoint raise Exception("First endpoint failed") - + # Parse the response data = await response.json() result = data.get("result", {}) return Task.from_dict(result) - + except Exception: # Try alternate endpoint async with session.post( @@ -573,8 +579,8 @@ async def send_task(self, task: Task) -> Task: "jsonrpc": "2.0", "id": 1, "method": "tasks/send", - "params": task.to_dict() - } + "params": task.to_dict(), + }, ) as response: # Handle HTTP errors if response.status >= 400: @@ -582,33 +588,30 @@ async def send_task(self, task: Task) -> Task: raise A2AConnectionError( f"HTTP error {response.status}: {error_text}" ) - + # Parse the response data = await response.json() result = data.get("result", {}) return Task.from_dict(result) - + except Exception as e: if isinstance(e, (A2AConnectionError, A2AResponseError)): raise - + # Create an error task as response - task.status = TaskStatus( - state=TaskState.FAILED, - message={"error": str(e)} - ) + task.status = TaskStatus(state=TaskState.FAILED, message={"error": str(e)}) return task - + async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]: """ Send a task and subscribe to streaming updates using tasks/sendSubscribe. - + Args: task: The task to send and subscribe to - + Yields: Task updates as they arrive - + Raises: A2AConnectionError: If connection to the agent fails A2AResponseError: If the agent returns an invalid response @@ -619,7 +622,7 @@ async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]: "aiohttp is required for tasks_send_subscribe. " "Install it with 'pip install aiohttp'." ) - + # Check if streaming is supported supports_streaming = await self.check_streaming_support() if not supports_streaming: @@ -627,98 +630,104 @@ async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]: task_result = await self.send_task(task) yield task_result return - + # Real streaming implementation with aiohttp try: # Set up streaming request async with self._create_session() as session: headers = dict(self.headers) # Add headers to request server-sent events - headers['Accept'] = 'text/event-stream' - + headers["Accept"] = "text/event-stream" + # Use the direct task instead of JsonRPC format for better compatibility request_data = task.to_dict() - + # Add debug logging logger.debug(f"Sending task streaming request with task ID: {task.id}") - + # Store the endpoint URLs to try endpoints_to_try = [] - + # If a custom stream_task_url is set, use it first - if hasattr(self, '_stream_task_url') and self._stream_task_url: - logger.debug(f"Using custom task streaming URL: {self._stream_task_url}") + if hasattr(self, "_stream_task_url") and self._stream_task_url: + logger.debug( + f"Using custom task streaming URL: {self._stream_task_url}" + ) endpoints_to_try.append(self._stream_task_url) - + # Then try standard endpoints - endpoints_to_try.extend([ - f"{self.url}/a2a/tasks/stream", # Try A2A-specific endpoint first - f"{self.url}/tasks/stream", # Then standard tasks endpoint - f"{self.url}/stream" # Finally fallback to regular stream endpoint - ]) - + endpoints_to_try.extend( + [ + f"{self.url}/a2a/tasks/stream", # Try A2A-specific endpoint first + f"{self.url}/tasks/stream", # Then standard tasks endpoint + f"{self.url}/stream", # Finally fallback to regular stream endpoint + ] + ) + response = None last_error = None - + # Try each endpoint in order for endpoint_url in endpoints_to_try: try: logger.debug(f"Trying task streaming endpoint: {endpoint_url}") - + # Close previous response if we had one if response: await response.release() - + # Send the request to this endpoint response = await session.post( - endpoint_url, - json=request_data, - headers=headers + endpoint_url, json=request_data, headers=headers ) - + # Check for success if response.status < 400: - logger.debug(f"Successfully connected to endpoint: {endpoint_url}") + logger.debug( + f"Successfully connected to endpoint: {endpoint_url}" + ) break - + # Store error for retry error_text = await response.text() - last_error = A2AConnectionError(f"HTTP error {response.status}: {error_text}") - + last_error = A2AConnectionError( + f"HTTP error {response.status}: {error_text}" + ) + except Exception as req_error: # Log the error and continue to next endpoint logger.debug(f"Error with endpoint {endpoint_url}: {req_error}") last_error = req_error - + # If we didn't get a successful response, raise the last error if not response or response.status >= 400: if last_error: raise last_error else: raise A2AConnectionError("All task streaming endpoints failed") - + try: # Process the streaming response buffer = "" current_task = task - + async for chunk in response.content.iter_chunks(): if not chunk: continue - + # Decode chunk - chunk_text = chunk[0].decode('utf-8') + chunk_text = chunk[0].decode("utf-8") buffer += chunk_text - + # Process complete events (separated by double newlines) while "\n\n" in buffer: event, buffer = buffer.split("\n\n", 1) - + # Extract data fields and event type from event event_type = "update" # Default event type event_data = None event_id = None - + for line in event.split("\n"): if line.startswith("event:"): event_type = line[6:].strip() @@ -726,81 +735,93 @@ async def tasks_send_subscribe(self, task: Task) -> AsyncGenerator[Task, None]: event_data = line[5:].strip() elif line.startswith("id:"): event_id = line[3:].strip() - + # Skip if no data if not event_data: continue - + # Try to parse the data as JSON try: data_obj = json.loads(event_data) - + # Handle task updates if event_type == "update" or event_type == "complete": if isinstance(data_obj, dict): # Parse as a Task - current_task = Task.from_dict(data_obj) + task_data = data_obj.get("task", data_obj) + current_task = Task.from_dict(task_data) yield current_task - + # If this is a complete event, we're done - if event_type == "complete" or current_task.status.state in [ - TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED - ]: + if ( + event_type == "complete" + or current_task.status.state + in [ + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + ] + ): return - + # Handle other event types elif event_type == "error": error_msg = data_obj.get("error", "Unknown error") - raise A2AStreamingError(f"Stream error: {error_msg}") - + raise A2AStreamingError( + f"Stream error: {error_msg}" + ) + # Handle raw data (artifact updates, etc.) else: # Update the current task with the new data # This is a simplification; real updates should merge properly if "artifacts" in data_obj: current_task.artifacts = data_obj["artifacts"] - + if "status" in data_obj: - current_task.status = TaskStatus.from_dict(data_obj["status"]) - + current_task.status = TaskStatus.from_dict( + data_obj["status"] + ) + yield current_task - + except json.JSONDecodeError: # Not JSON, create a text update - logger.warning(f"Received non-JSON data in stream: {event_data[:50]}...") + logger.warning( + f"Received non-JSON data in stream: {event_data[:50]}..." + ) # Create a text artifact for backward compatibility - current_task.artifacts.append({ - "parts": [{"type": "text", "text": event_data}] - }) + current_task.artifacts.append( + {"parts": [{"type": "text", "text": event_data}]} + ) yield current_task - + finally: # Ensure we close the response if response: await response.release() - + except Exception as e: if isinstance(e, (A2AConnectionError, A2AResponseError, A2AStreamingError)): raise - + # For any other error, yield a complete task with the error - task.status = TaskStatus( - state=TaskState.FAILED, - message={"error": str(e)} - ) + task.status = TaskStatus(state=TaskState.FAILED, message={"error": str(e)}) yield task - - async def tasks_resubscribe(self, task_id: str, session_id: Optional[str] = None) -> AsyncGenerator[Task, None]: + + async def tasks_resubscribe( + self, task_id: str, session_id: Optional[str] = None + ) -> AsyncGenerator[Task, None]: """ Resubscribe to an existing task's updates. - + Args: task_id: The ID of the task to resubscribe to session_id: Optional session ID (if known) - + Yields: Task updates as they arrive - + Raises: A2AConnectionError: If connection to the agent fails A2AResponseError: If the agent returns an invalid response @@ -811,57 +832,53 @@ async def tasks_resubscribe(self, task_id: str, session_id: Optional[str] = None "aiohttp is required for tasks_resubscribe. " "Install it with 'pip install aiohttp'." ) - + # Check if streaming is supported supports_streaming = await self.check_streaming_support() if not supports_streaming: raise A2AStreamingError("Agent does not support streaming") - + # Real streaming implementation with aiohttp try: # Set up streaming request async with self._create_session() as session: headers = dict(self.headers) # Add headers to request server-sent events - headers['Accept'] = 'text/event-stream' - + headers["Accept"] = "text/event-stream" + # Create JsonRPC request request_data = { "jsonrpc": "2.0", "id": 1, "method": "tasks/resubscribe", - "params": { - "id": task_id - } + "params": {"id": task_id}, } - + # Add session_id if provided if session_id: request_data["params"]["sessionId"] = session_id - + # Try primary endpoint first endpoint_url = f"{self.url}/tasks/stream" response = None try: response = await session.post( - endpoint_url, - json=request_data, - headers=headers + endpoint_url, json=request_data, headers=headers ) - + # Check for HTTP errors if response.status >= 400: # Try alternate endpoint - logger.debug(f"Primary endpoint failed with status {response.status}, trying alternate") + logger.debug( + f"Primary endpoint failed with status {response.status}, trying alternate" + ) if response: await response.release() endpoint_url = f"{self.url}/a2a/tasks/stream" response = await session.post( - endpoint_url, - json=request_data, - headers=headers + endpoint_url, json=request_data, headers=headers ) - + # Check for HTTP errors again if response.status >= 400: error_text = await response.text() @@ -870,45 +887,45 @@ async def tasks_resubscribe(self, task_id: str, session_id: Optional[str] = None ) except Exception as req_error: # Try alternate endpoint if first fails - logger.debug(f"Primary endpoint failed: {req_error}, trying alternate") + logger.debug( + f"Primary endpoint failed: {req_error}, trying alternate" + ) if response: await response.release() endpoint_url = f"{self.url}/a2a/tasks/stream" response = await session.post( - endpoint_url, - json=request_data, - headers=headers + endpoint_url, json=request_data, headers=headers ) - + # Check for HTTP errors if response.status >= 400: error_text = await response.text() raise A2AConnectionError( f"HTTP error {response.status}: {error_text}" ) - + try: # Process the streaming response buffer = "" current_task = None - + async for chunk in response.content.iter_chunks(): if not chunk: continue - + # Decode chunk - chunk_text = chunk[0].decode('utf-8') + chunk_text = chunk[0].decode("utf-8") buffer += chunk_text - + # Process complete events (separated by double newlines) while "\n\n" in buffer: event, buffer = buffer.split("\n\n", 1) - + # Extract data fields and event type from event event_type = "update" # Default event type event_data = None event_id = None - + for line in event.split("\n"): if line.startswith("event:"): event_type = line[6:].strip() @@ -916,103 +933,108 @@ async def tasks_resubscribe(self, task_id: str, session_id: Optional[str] = None event_data = line[5:].strip() elif line.startswith("id:"): event_id = line[3:].strip() - + # Skip if no data if not event_data: continue - + # Try to parse the data as JSON try: data_obj = json.loads(event_data) - + # Handle task updates if event_type == "update" or event_type == "complete": if isinstance(data_obj, dict): # Parse as a Task current_task = Task.from_dict(data_obj) yield current_task - + # If this is a complete event, we're done - if event_type == "complete" or (current_task and current_task.status.state in [ - TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED - ]): + if event_type == "complete" or ( + current_task + and current_task.status.state + in [ + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + ] + ): return - + # Handle other event types elif event_type == "error": error_msg = data_obj.get("error", "Unknown error") - raise A2AStreamingError(f"Stream error: {error_msg}") - + raise A2AStreamingError( + f"Stream error: {error_msg}" + ) + # Handle raw data (artifact updates, etc.) else: # Initialize a task if we don't have one yet if not current_task: current_task = Task( - id=task_id, - session_id=session_id + id=task_id, session_id=session_id ) - + # Update the current task with the new data if "artifacts" in data_obj: current_task.artifacts = data_obj["artifacts"] - + if "status" in data_obj: - current_task.status = TaskStatus.from_dict(data_obj["status"]) - + current_task.status = TaskStatus.from_dict( + data_obj["status"] + ) + yield current_task - + except json.JSONDecodeError: # Not JSON, create a text update - logger.warning(f"Received non-JSON data in stream: {event_data[:50]}...") - + logger.warning( + f"Received non-JSON data in stream: {event_data[:50]}..." + ) + # Initialize a task if we don't have one yet if not current_task: current_task = Task( - id=task_id, - session_id=session_id + id=task_id, session_id=session_id ) - + # Create a text artifact for backward compatibility - current_task.artifacts.append({ - "parts": [{"type": "text", "text": event_data}] - }) + current_task.artifacts.append( + {"parts": [{"type": "text", "text": event_data}]} + ) yield current_task - + finally: # Ensure we close the response if response: await response.release() - + except Exception as e: if isinstance(e, (A2AConnectionError, A2AResponseError, A2AStreamingError)): raise - + # For any other error, yield a task with the error error_task = Task( id=task_id, session_id=session_id, - status=TaskStatus( - state=TaskState.FAILED, - message={"error": str(e)} - ) + status=TaskStatus(state=TaskState.FAILED, message={"error": str(e)}), ) yield error_task - + async def stream_task( - self, - task: Task, - chunk_callback: Optional[Callable[[Dict], None]] = None + self, task: Task, chunk_callback: Optional[Callable[[Dict], None]] = None ) -> AsyncGenerator[Dict, None]: """ Stream the execution of a task. - + Args: task: The task to execute chunk_callback: Optional callback function for each chunk - + Yields: Task status and result chunks - + Raises: A2AConnectionError: If connection to the agent fails A2AResponseError: If the agent returns an invalid response @@ -1023,80 +1045,80 @@ async def stream_task( # Extract status and artifacts for backward compatibility chunk = { "status": task_update.status.state.value, - "artifacts": task_update.artifacts + "artifacts": task_update.artifacts, } - + # Call the callback if provided if chunk_callback: chunk_callback(chunk) - + yield chunk - + # If the task is complete, we're done - if task_update.status.state in [TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED]: + if task_update.status.state in [ + TaskState.COMPLETED, + TaskState.FAILED, + TaskState.CANCELED, + ]: return - + # If we reach here, we've completed the task return - + except (A2AStreamingError, ImportError) as e: # Fall back to legacy implementation - logger.debug(f"Enhanced streaming not supported or failed: {e}. Falling back to legacy implementation.") + logger.debug( + f"Enhanced streaming not supported or failed: {e}. Falling back to legacy implementation." + ) pass - + # Legacy implementation starts here # Check if streaming is supported supports_streaming = await self.check_streaming_support() - + if not supports_streaming: # Fall back to non-streaming if not supported result = await self.send_task(task) - + # Create a single chunk with the complete result - chunk = { - "status": result.status.state.value, - "artifacts": result.artifacts - } - + chunk = {"status": result.status.state.value, "artifacts": result.artifacts} + # Yield the entire response as one chunk if chunk_callback: chunk_callback(chunk) yield chunk return - + if not self._has_aiohttp: # Fall back to non-streaming if aiohttp not available result = await self.send_task(task) - + # Create a single chunk with the complete result - chunk = { - "status": result.status.state.value, - "artifacts": result.artifacts - } - + chunk = {"status": result.status.state.value, "artifacts": result.artifacts} + # Yield the entire response as one chunk if chunk_callback: chunk_callback(chunk) yield chunk return - + # Real streaming implementation with aiohttp try: # Set up streaming request async with self._create_session() as session: headers = dict(self.headers) # Add headers to request server-sent events - headers['Accept'] = 'text/event-stream' - + headers["Accept"] = "text/event-stream" + async with session.post( f"{self.url}/tasks/stream", json={ "jsonrpc": "2.0", "id": 1, "method": "tasks/stream", - "params": task.to_dict() + "params": task.to_dict(), }, - headers=headers + headers=headers, ) as response: # Handle HTTP errors if response.status >= 400: @@ -1108,9 +1130,9 @@ async def stream_task( "jsonrpc": "2.0", "id": 1, "method": "tasks/stream", - "params": task.to_dict() + "params": task.to_dict(), }, - headers=headers + headers=headers, ) as alt_response: # Handle HTTP errors if alt_response.status >= 400: @@ -1118,11 +1140,13 @@ async def stream_task( raise A2AConnectionError( f"HTTP error {alt_response.status}: {error_text}" ) - + # Process the streaming response - async for chunk in self._process_stream(alt_response, chunk_callback): + async for chunk in self._process_stream( + alt_response, chunk_callback + ): yield chunk - + except Exception: error_text = await response.text() raise A2AConnectionError( @@ -1130,28 +1154,27 @@ async def stream_task( ) else: # Process the streaming response from original endpoint - async for chunk in self._process_stream(response, chunk_callback): + async for chunk in self._process_stream( + response, chunk_callback + ): yield chunk - + except Exception as e: if isinstance(e, (A2AConnectionError, A2AResponseError)): raise - + # Fall back to non-streaming for other errors logger.warning(f"Error in streaming, falling back to non-streaming: {e}") result = await self.send_task(task) - + # Create a single chunk with the complete result - chunk = { - "status": result.status.state.value, - "artifacts": result.artifacts - } - + chunk = {"status": result.status.state.value, "artifacts": result.artifacts} + # Yield the entire response as one chunk if chunk_callback: chunk_callback(chunk) yield chunk - + async def _process_stream(self, response, chunk_callback=None): """Process a streaming response using enhanced parsing.""" try: @@ -1159,51 +1182,53 @@ async def _process_stream(self, response, chunk_callback=None): last_event_type = None chunks_received = 0 bytes_received = 0 - + # Debug logging logger.debug(f"Starting to process streaming response") logger.debug(f"Response headers: {response.headers}") - + async for chunk in response.content.iter_chunks(): if not chunk: continue - + # Update metrics chunks_received += 1 bytes_received += len(chunk[0]) - + # Decode chunk - chunk_text = chunk[0].decode('utf-8') + chunk_text = chunk[0].decode("utf-8") buffer += chunk_text - + # Debug every 10 chunks if chunks_received % 10 == 0: - logger.debug(f"Processed {chunks_received} chunks, {bytes_received} bytes") - + logger.debug( + f"Processed {chunks_received} chunks, {bytes_received} bytes" + ) + # Detailed debug for first few chunks if chunks_received <= 3: logger.debug(f"Raw chunk {chunks_received}: {chunk_text}") - + # Process complete events (separated by double newlines) while "\n\n" in buffer: event, buffer = buffer.split("\n\n", 1) - + # Skip comments (lines starting with colon) - if event.startswith(':'): + if event.startswith(":"): logger.debug(f"Skipping SSE comment: {event}") continue - + # Extract fields from the event event_type = None event_data = None event_id = None retry_time = None - + for line in event.split("\n"): line = line.strip() if not line: continue - + if line.startswith("event:"): event_type = line[6:].strip() logger.debug(f"Found event type: {event_type}") @@ -1219,18 +1244,18 @@ async def _process_stream(self, response, chunk_callback=None): logger.debug(f"Found retry time: {retry_time}") except ValueError: pass - + # Default to "message" event type if none provided if not event_type: event_type = last_event_type or "message" - + last_event_type = event_type - + # Handle connected event if event_type == "connected": logger.info("Received connected event from server") continue - + # Handle error events if event_type == "error": if event_data: @@ -1238,37 +1263,43 @@ async def _process_stream(self, response, chunk_callback=None): error_data = json.loads(event_data) logger.error(f"Received error event: {error_data}") # Raise exception to be caught by the outer handler - raise A2AStreamingError(error_data.get("error", "Unknown streaming error")) + raise A2AStreamingError( + error_data.get("error", "Unknown streaming error") + ) except json.JSONDecodeError: # Bad JSON in error, use raw text - logger.error(f"Received malformed error event: {event_data}") + logger.error( + f"Received malformed error event: {event_data}" + ) raise A2AStreamingError(f"Stream error: {event_data}") continue - + # Skip if no data if not event_data: logger.warning("Empty event data, skipping") continue - + # Try to parse as JSON try: data_obj = json.loads(event_data) - logger.debug(f"Successfully parsed JSON data: {str(data_obj)[:50]}...") - + logger.debug( + f"Successfully parsed JSON data: {str(data_obj)[:50]}..." + ) + # Handle structured events with the new StreamingChunk class if isinstance(data_obj, dict): streaming_chunk = StreamingChunk.from_dict(data_obj) - + # If lastChunk is set, this is the final chunk is_last = streaming_chunk.is_last - + # Process with callback if provided if chunk_callback: chunk_callback(data_obj) - + # Yield the chunk yield data_obj - + # If this is the last chunk, we're done if is_last: logger.info("Received last chunk, ending stream") @@ -1281,28 +1312,30 @@ async def _process_stream(self, response, chunk_callback=None): yield data_obj except json.JSONDecodeError: # Not JSON, create a text chunk - logger.warning(f"Failed to parse JSON, treating as text: {event_data[:50]}...") + logger.warning( + f"Failed to parse JSON, treating as text: {event_data[:50]}..." + ) text_chunk = {"type": "text", "text": event_data} if chunk_callback: chunk_callback(text_chunk) yield text_chunk - + # Log completion logger.info(f"Stream completed, processed {chunks_received} raw chunks") - + except Exception as e: logger.error(f"Error processing streaming response: {e}") # Include stack trace for debugging logger.debug(traceback.format_exc()) raise - + def _extract_json_from_response(self, text): """ Extract JSON data from a response that might be HTML or contain embedded JSON. - + Args: text: The response text to parse - + Returns: Extracted JSON data as dictionary """ @@ -1311,18 +1344,18 @@ def _extract_json_from_response(self, text): return json.loads(text) except json.JSONDecodeError: pass - + # Try to extract JSON from HTML (look for JSON in a
 tag or