diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index c5d7195585..1ce2216fe7 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -581,7 +581,7 @@ def subscribe_to_query_processing(self, query_observable: Observable) -> Disposa Returns: Disposable: A disposable representing the subscription. """ - print_emission_args = {"enabled": True, "dev_name": self.dev_name, "counts": {}} + print_emission_args = {"enabled": False, "dev_name": self.dev_name, "counts": {}} def _process_query(query) -> Observable: """ diff --git a/dimos/agents/cerebras_agent.py b/dimos/agents/cerebras_agent.py index 8aebf0f509..854beb848d 100644 --- a/dimos/agents/cerebras_agent.py +++ b/dimos/agents/cerebras_agent.py @@ -26,6 +26,8 @@ from typing import Any, Dict, List, Optional, Union, Tuple import logging import json +import re +import time from cerebras.cloud.sdk import Cerebras from dotenv import load_dotenv @@ -33,14 +35,16 @@ from reactivex import Observable from reactivex.observer import Observer from reactivex.scheduler import ThreadPoolScheduler -from openai._types import NOT_GIVEN # Local imports from dimos.agents.agent import LLMAgent from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.agents.tokenizer.base import AbstractTokenizer from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.stream.frame_processor import FrameProcessor from dimos.utils.logging_config import setup_logger +from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer # Initialize environment variables load_dotenv() @@ -49,6 +53,56 @@ logger = setup_logger("dimos.agents.cerebras") +# Response object compatible with LLMAgent +class CerebrasResponseMessage(dict): + def __init__( + self, + content="", + tool_calls=None, + ): + self.content = content + self.tool_calls = tool_calls or [] + self.parsed = None + + # Initialize as dict with the proper structure + super().__init__(self.to_dict()) + + def __str__(self): + # Return a string representation for logging + if self.content: + return self.content + elif self.tool_calls: + # Return JSON representation of the first tool call + if self.tool_calls: + tool_call = self.tool_calls[0] + tool_json = { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + return json.dumps(tool_json) + return "[No content]" + + def to_dict(self): + """Convert to dictionary format for JSON serialization.""" + result = {"role": "assistant", "content": self.content or ""} + + if self.tool_calls: + result["tool_calls"] = [] + for tool_call in self.tool_calls: + result["tool_calls"].append( + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + ) + + return result + + class CerebrasAgent(LLMAgent): """Cerebras agent implementation using the official Cerebras Python SDK. @@ -76,6 +130,8 @@ def __init__( image_detail: str = "low", pool_scheduler: Optional[ThreadPoolScheduler] = None, process_all_inputs: Optional[bool] = None, + tokenizer: Optional[AbstractTokenizer] = None, + prompt_builder: Optional[PromptBuilder] = None, ): """ Initializes a new instance of the CerebrasAgent. @@ -104,6 +160,8 @@ def __init__( image_detail (str): Detail level for images ("low", "high", "auto"). pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. process_all_inputs (bool): Whether to process all inputs or skip when busy. + tokenizer (AbstractTokenizer): The tokenizer for the agent. + prompt_builder (PromptBuilder): The prompt builder for the agent. """ # Determine appropriate default for process_all_inputs if not provided if process_all_inputs is None: @@ -154,10 +212,21 @@ def __init__( self.image_detail = image_detail self.max_output_tokens_per_request = max_output_tokens_per_request self.max_input_tokens_per_request = max_input_tokens_per_request + self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request # Add static context to memory. self._add_context_to_memory() + # Initialize tokenizer and prompt builder + self.tokenizer = tokenizer or OpenAITokenizer( + model_name="gpt-4o" + ) # Use GPT-4 tokenizer for better accuracy + self.prompt_builder = prompt_builder or PromptBuilder( + model_name=self.model_name, + max_tokens=self.max_input_tokens_per_request, + tokenizer=self.tokenizer, + ) + logger.info("Cerebras Agent Initialized.") def _add_context_to_memory(self): @@ -204,7 +273,6 @@ def _build_prompt( Returns: list: Messages formatted for Cerebras API. """ - # Add system message if provided and not already in history if self.system_query and (not messages or messages[0].get("role") != "system"): messages.insert(0, {"role": "system", "content": self.system_query}) @@ -239,19 +307,155 @@ def _build_prompt( logger.info(f"Added {len(images)} image(s) to conversation") + # Use new truncation function + messages = self._truncate_messages(messages, override_token_limit) + return messages - def _send_query(self, messages: list) -> Any: + def _truncate_messages(self, messages: list, override_token_limit: bool = False) -> list: + """Truncate messages if total tokens exceed 16k using existing truncate_tokens method. + + Args: + messages (list): List of message dictionaries + override_token_limit (bool): Whether to skip truncation + + Returns: + list: Messages with content truncated if needed + """ + if override_token_limit: + return messages + + total_tokens = 0 + for message in messages: + if isinstance(message.get("content"), str): + total_tokens += self.prompt_builder.tokenizer.token_count(message["content"]) + elif isinstance(message.get("content"), list): + for item in message["content"]: + if item.get("type") == "text": + total_tokens += self.prompt_builder.tokenizer.token_count(item["text"]) + elif item.get("type") == "image_url": + total_tokens += 85 + + if total_tokens > 16000: + excess_tokens = total_tokens - 16000 + current_tokens = total_tokens + + # Start from oldest messages and truncate until under 16k + for i in range(len(messages)): + if current_tokens <= 16000: + break + + msg = messages[i] + if msg.get("role") == "system": + continue + + if isinstance(msg.get("content"), str): + original_tokens = self.prompt_builder.tokenizer.token_count(msg["content"]) + # Calculate how much to truncate from this message + tokens_to_remove = min(excess_tokens, original_tokens // 3) + new_max_tokens = max(50, original_tokens - tokens_to_remove) + + msg["content"] = self.prompt_builder.truncate_tokens( + msg["content"], new_max_tokens, "truncate_end" + ) + + new_tokens = self.prompt_builder.tokenizer.token_count(msg["content"]) + tokens_saved = original_tokens - new_tokens + current_tokens -= tokens_saved + excess_tokens -= tokens_saved + + logger.info( + f"Truncated older messages using truncate_tokens, final tokens: {current_tokens}" + ) + else: + logger.info(f"No truncation needed, total tokens: {total_tokens}") + + return messages + + def clean_cerebras_schema(self, schema: dict) -> dict: + """Simple schema cleaner that removes unsupported fields for Cerebras API.""" + if not isinstance(schema, dict): + return schema + + # Removing the problematic fields that pydantic generates + cleaned = {} + unsupported_fields = { + "minItems", + "maxItems", + "uniqueItems", + "exclusiveMinimum", + "exclusiveMaximum", + "minimum", + "maximum", + } + + for key, value in schema.items(): + if key in unsupported_fields: + continue # Skip unsupported fields + elif isinstance(value, dict): + cleaned[key] = self.clean_cerebras_schema(value) + elif isinstance(value, list): + cleaned[key] = [ + self.clean_cerebras_schema(item) if isinstance(item, dict) else item + for item in value + ] + else: + cleaned[key] = value + + return cleaned + + def create_tool_call( + self, name: str = None, arguments: dict = None, call_id: str = None, content: str = None + ): + """Create a tool call object from either direct parameters or JSON content.""" + # If content is provided, parse it as JSON + if content: + logger.info(f"Creating tool call from content: {content}") + try: + content_json = json.loads(content) + if ( + isinstance(content_json, dict) + and "name" in content_json + and "arguments" in content_json + ): + name = content_json["name"] + arguments = content_json["arguments"] + else: + return None + except json.JSONDecodeError: + logger.warning("Content appears to be JSON but failed to parse") + return None + + # Create the tool call object + if name and arguments is not None: + timestamp = int(time.time() * 1000000) # microsecond precision + tool_id = f"call_{timestamp}" + + logger.info(f"Creating tool call with timestamp ID: {tool_id}") + return type( + "ToolCall", + (), + { + "id": tool_id, + "function": type( + "Function", (), {"name": name, "arguments": json.dumps(arguments)} + ), + }, + ) + + return None + + def _send_query(self, messages: list) -> CerebrasResponseMessage: """Sends the query to Cerebras API using the official Cerebras SDK. Args: messages (list): The prompt messages to send. Returns: - The response message from Cerebras. + The response message from Cerebras wrapped in our CerebrasResponseMessage class. Raises: - Exception: If no response message is returned. + Exception: If no response message is returned from the API. ConnectionError: If there's an issue connecting to the API. ValueError: If the messages or other parameters are invalid. """ @@ -260,46 +464,46 @@ def _send_query(self, messages: list) -> Any: api_params = { "model": self.model_name, "messages": messages, - "max_tokens": self.max_output_tokens_per_request, + # "max_tokens": self.max_output_tokens_per_request, } # Add tools if available if self.skill_library and self.skill_library.get_tools(): tools = self.skill_library.get_tools() - api_params["tools"] = tools # No conversion needed + for tool in tools: + if "function" in tool and "parameters" in tool["function"]: + tool["function"]["parameters"] = self.clean_cerebras_schema( + tool["function"]["parameters"] + ) + api_params["tools"] = tools api_params["tool_choice"] = "auto" - # Add response format for structured output if specified if self.response_model is not None: - # Convert Pydantic model to JSON schema for Cerebras - from pydantic import TypeAdapter - - schema = TypeAdapter(self.response_model).json_schema() - - # Ensure additionalProperties is set to False for strict mode - if "additionalProperties" not in schema: - schema["additionalProperties"] = False - api_params["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": self.response_model.__name__ - if hasattr(self.response_model, "__name__") - else "response", - "strict": True, - "schema": schema, - }, + "type": "json_object", + "schema": self.response_model, } # Make the API call response = self.client.chat.completions.create(**api_params) - response_message = response.choices[0].message - if response_message is None: + raw_message = response.choices[0].message + if raw_message is None: logger.error("Response message does not exist.") raise Exception("Response message does not exist.") - return response_message + # Process response into final format + content = raw_message.content + tool_calls = getattr(raw_message, "tool_calls", None) + + # If no structured tool calls from API, try parsing content as JSON tool call + if not tool_calls and content and content.strip().startswith("{"): + parsed_tool_call = self.create_tool_call(content=content) + if parsed_tool_call: + tool_calls = [parsed_tool_call] + content = None + + return CerebrasResponseMessage(content=content, tool_calls=tool_calls) except ConnectionError as ce: logger.error(f"Connection error with Cerebras API: {ce}") @@ -308,6 +512,8 @@ def _send_query(self, messages: list) -> Any: logger.error(f"Invalid parameters for Cerebras API: {ve}") raise except Exception as e: + # Print the raw API parameters when an error occurs + logger.error(f"Raw API parameters: {json.dumps(api_params, indent=2)}") logger.error(f"Unexpected error in Cerebras API call: {e}") raise @@ -340,7 +546,6 @@ def _observable_query( # Create a local copy of conversation history and record its length messages = copy.deepcopy(self.conversation_history) - base_len = len(messages) # Update query and get context self._update_query(incoming_query) @@ -351,94 +556,53 @@ def _observable_query( messages, base64_image, dimensions, override_token_limit, condensed_results ) - # Send query and get response - logger.info("Sending Query.") - response_message = self._send_query(messages) - logger.info(f"Received Response: {response_message}") - - if response_message is None: - logger.error("Received None response from Cerebras API") - observer.on_next("") - observer.on_completed() - return - - # Add assistant response to local messages (always) - assistant_message = {"role": "assistant"} - - if response_message.content: - assistant_message["content"] = response_message.content - else: - assistant_message["content"] = "" # Ensure content is never None - - if hasattr(response_message, "tool_calls") and response_message.tool_calls: - assistant_message["tool_calls"] = [] - for tool_call in response_message.tool_calls: - assistant_message["tool_calls"].append( - { - "id": tool_call.id, - "type": "function", - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - }, - } + while True: + logger.info("Sending Query.") + response_message = self._send_query(messages) + logger.info(f"Received Response: {response_message}") + + if response_message is None: + raise Exception("Response message does not exist.") + + # If no skill library or no tool calls, we're done + if ( + self.skill_library is None + or self.skill_library.get_tools() is None + or response_message.tool_calls is None + ): + final_msg = ( + response_message.parsed + if hasattr(response_message, "parsed") and response_message.parsed + else ( + response_message.content + if hasattr(response_message, "content") + else response_message + ) ) - logger.info( - f"Assistant response includes {len(response_message.tool_calls)} tool call(s)" - ) + messages.append(response_message) + break - messages.append(assistant_message) + logger.info(f"Assistant requested {len(response_message.tool_calls)} tool call(s)") + next_response = self._handle_tooling(response_message, messages) - # Handle tool calls if present (add tool messages to conversation) - self._handle_tooling(response_message, messages) + if next_response is None: + final_msg = response_message.content or "" + break + + response_message = next_response - # At the end, append only new messages to the global conversation history under a lock - if not hasattr(self, "_history_lock"): - self._history_lock = threading.Lock() with self._history_lock: - for msg in messages[base_len:]: - self.conversation_history.append(msg) + self.conversation_history = messages logger.info( f"Updated conversation history (total: {len(self.conversation_history)} messages)" ) - # Send response to observers - result = response_message.content or "" - observer.on_next(result) - self.response_subject.on_next(result) + # Emit the final message content to the observer + observer.on_next(final_msg) + self.response_subject.on_next(final_msg) observer.on_completed() except Exception as e: logger.error(f"Query failed in {self.dev_name}: {e}") observer.on_error(e) self.response_subject.on_error(e) - - def _handle_tooling(self, response_message, messages): - """Executes tools and appends tool-use/result blocks to messages.""" - if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: - logger.info("No tool calls found in response message") - return None - - if len(response_message.tool_calls) > 1: - logger.warning( - "Multiple tool calls detected in response message. Not a tested feature." - ) - - # Execute all tools and add their results to messages - for tool_call in response_message.tool_calls: - logger.info(f"Processing tool call: {tool_call.function.name}") - - # Execute the tool - args = json.loads(tool_call.function.arguments) - tool_result = self.skill_library.call(tool_call.function.name, **args) - logger.info(f"Function Call Results: {tool_result}") - - # Add tool result to conversation history (OpenAI format) - messages.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "content": str(tool_result), - "name": tool_call.function.name, - } - ) diff --git a/dimos/skills/observe_stream.py b/dimos/skills/observe_stream.py index dc38052508..4449de2995 100644 --- a/dimos/skills/observe_stream.py +++ b/dimos/skills/observe_stream.py @@ -203,7 +203,6 @@ def _process_frame(self, frame): observable = self._agent.run_observable_query( f"{self.query_text}\n\nHere is the current camera view from the robot:", base64_image=base64_image, - thinking_budget_tokens=0, ) # Simple subscription to make sure the query executes diff --git a/tests/test_cerebras_agent_query.py b/tests/test_cerebras_agent_query.py new file mode 100644 index 0000000000..5ed4007eed --- /dev/null +++ b/tests/test_cerebras_agent_query.py @@ -0,0 +1,29 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import test_header + +from dotenv import load_dotenv +from dimos.agents.cerebras_agent import CerebrasAgent + +# Load API key from environment +load_dotenv() + +# Create a CerebrasAgent instance +agent = CerebrasAgent(dev_name="test_agent", query="What is the capital of France?") + +# Use the stream_query method to get a response +response = agent.run_observable_query("What is the capital of France?").run() + +print(f"Response from Cerebras Agent: {response}") diff --git a/tests/test_cerebras_unitree_ros.py b/tests/test_cerebras_unitree_ros.py new file mode 100644 index 0000000000..cbb7c130db --- /dev/null +++ b/tests/test_cerebras_unitree_ros.py @@ -0,0 +1,118 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from dimos.robot.robot import MockRobot +import tests.test_header + +import time +from dotenv import load_dotenv +from dimos.agents.cerebras_agent import CerebrasAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal +from dimos.skills.visual_navigation_skills import FollowHuman +import reactivex as rx +import reactivex.operators as ops +from dimos.stream.audio.pipelines import tts, stt +from dimos.web.websocket_vis.server import WebsocketVis +import threading +from dimos.types.vector import Vector +from dimos.skills.speak import Speak + +# Load API key from environment +load_dotenv() + +# robot = MockRobot() +robot_skills = MyUnitreeSkills() + +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=robot_skills, + mock_connection=False, + new_memory=True, +) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) + +streams = { + "unitree_video": robot.get_ros_video_stream(), +} +text_streams = { + "agent_responses": agent_response_stream, +} + +web_interface = RobotWebInterface( + port=5555, + text_streams=text_streams, + **streams, +) + +stt_node = stt() + +# Create a CerebrasAgent instance +agent = CerebrasAgent( + dev_name="test_cerebras_agent", + input_query_stream=stt_node.emit_text(), + # input_query_stream=web_interface.query_stream, + skills=robot_skills, + system_query="""You are an agent controlling a virtual robot. When given a query, respond by using the appropriate tool calls if needed to execute commands on the robot. + +IMPORTANT INSTRUCTIONS: +1. Each tool call must include the exact function name and appropriate parameters +2. If a function needs parameters like 'distance' or 'angle', be sure to include them +3. If you're unsure which tool to use, choose the most appropriate one based on the user's query +4. Parse the user's instructions carefully to determine correct parameter values + +When you need to call a skill or tool, ALWAYS respond ONLY with a JSON object in this exact format: {"name": "SkillName", "arguments": {"arg1": "value1", "arg2": "value2"}} + +Example: If the user asks to spin right by 90 degrees, output ONLY the following: {"name": "SpinRight", "arguments": {"degrees": 90}}""", + model_name="llama-4-scout-17b-16e-instruct", +) + +tts_node = tts() +tts_node.consume_text(agent.get_response_observable()) + +robot_skills.add(ObserveStream) +robot_skills.add(KillSkill) +robot_skills.add(NavigateWithText) +robot_skills.add(FollowHuman) +robot_skills.add(GetPose) +robot_skills.add(Speak) +robot_skills.add(NavigateToGoal) +robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +robot_skills.create_instance("NavigateWithText", robot=robot) +robot_skills.create_instance("FollowHuman", robot=robot) +robot_skills.create_instance("GetPose", robot=robot) +robot_skills.create_instance("NavigateToGoal", robot=robot) + + +robot_skills.create_instance("Speak", tts_node=tts_node) + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +# print(f"Registered skills: {', '.join([skill.__name__ for skill in robot_skills.skills])}") +print("Cerebras agent demo initialized. You can now interact with the agent via the web interface.") + +web_interface.run()