diff --git a/dimos/agents/cerebras_agent.py b/dimos/agents/cerebras_agent.py new file mode 100644 index 0000000000..8aebf0f509 --- /dev/null +++ b/dimos/agents/cerebras_agent.py @@ -0,0 +1,444 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cerebras agent implementation for the DIMOS agent framework. + +This module provides a CerebrasAgent class that implements the LLMAgent interface +for Cerebras inference API using the official Cerebras Python SDK. +""" + +from __future__ import annotations + +import os +import threading +import copy +from typing import Any, Dict, List, Optional, Union, Tuple +import logging +import json + +from cerebras.cloud.sdk import Cerebras +from dotenv import load_dotenv +from pydantic import BaseModel +from reactivex import Observable +from reactivex.observer import Observer +from reactivex.scheduler import ThreadPoolScheduler +from openai._types import NOT_GIVEN + +# Local imports +from dimos.agents.agent import LLMAgent +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.stream.frame_processor import FrameProcessor +from dimos.utils.logging_config import setup_logger + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the Cerebras agent +logger = setup_logger("dimos.agents.cerebras") + + +class CerebrasAgent(LLMAgent): + """Cerebras agent implementation using the official Cerebras Python SDK. + + This class implements the _send_query method to interact with Cerebras API + using their official SDK, allowing most of the LLMAgent logic to be reused. + """ + + def __init__( + self, + dev_name: str, + agent_type: str = "Vision", + query: str = "What do you see?", + input_query_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + input_data_stream: Optional[Observable] = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + system_query: Optional[str] = None, + max_input_tokens_per_request: int = 128000, + max_output_tokens_per_request: int = 16384, + model_name: str = "llama-4-scout-17b-16e-instruct", + skills: Optional[Union[AbstractSkill, list[AbstractSkill], SkillLibrary]] = None, + response_model: Optional[BaseModel] = None, + frame_processor: Optional[FrameProcessor] = None, + image_detail: str = "low", + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: Optional[bool] = None, + ): + """ + Initializes a new instance of the CerebrasAgent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + query (str): The default query text. + input_query_stream (Observable): An observable for query input. + input_video_stream (Observable): An observable for video frames. + input_data_stream (Observable): An observable for data input. + output_dir (str): Directory for output files. 
+ agent_memory (AbstractAgentSemanticMemory): The memory system. + system_query (str): The system prompt to use with RAG context. + max_input_tokens_per_request (int): Maximum tokens for input. + max_output_tokens_per_request (int): Maximum tokens for output. + model_name (str): The Cerebras model name to use. Available options: + - llama-4-scout-17b-16e-instruct (default, fastest) + - llama3.1-8b + - llama-3.3-70b + - qwen-3-32b + - deepseek-r1-distill-llama-70b (private preview) + skills (Union[AbstractSkill, List[AbstractSkill], SkillLibrary]): Skills available to the agent. + response_model (BaseModel): Optional Pydantic model for structured responses. + frame_processor (FrameProcessor): Custom frame processor. + image_detail (str): Detail level for images ("low", "high", "auto"). + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + process_all_inputs (bool): Whether to process all inputs or skip when busy. + """ + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + # Default to True for text queries, False for video streams + if input_query_stream is not None and input_video_stream is None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + input_query_stream=input_query_stream, + input_video_stream=input_video_stream, + input_data_stream=input_data_stream, + ) + + # Initialize Cerebras client + self.client = Cerebras() + + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + # Initialize conversation history for multi-turn conversations + self.conversation_history = [] + self._history_lock = threading.Lock() + + # Configure skills + self.skills = skills + self.skill_library = None + if isinstance(self.skills, SkillLibrary): + self.skill_library = self.skills + elif isinstance(self.skills, list): + self.skill_library = SkillLibrary() + for skill in self.skills: + self.skill_library.add(skill) + elif isinstance(self.skills, AbstractSkill): + self.skill_library = SkillLibrary() + self.skill_library.add(self.skills) + + self.response_model = response_model + self.model_name = model_name + self.image_detail = image_detail + self.max_output_tokens_per_request = max_output_tokens_per_request + self.max_input_tokens_per_request = max_input_tokens_per_request + + # Add static context to memory. 
+ self._add_context_to_memory() + + logger.info("Cerebras Agent Initialized.") + + def _add_context_to_memory(self): + """Adds initial context to the agent's memory.""" + context_data = [ + ( + "id0", + "Optical Flow is a technique used to track the movement of objects in a video sequence.", + ), + ( + "id1", + "Edge Detection is a technique used to identify the boundaries of objects in an image.", + ), + ("id2", "Video is a sequence of frames captured at regular intervals."), + ( + "id3", + "Colors in an Optical Flow visualization encode the direction and magnitude of motion, and can be used to track the movement of objects.", + ), + ( + "id4", + "JSON is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", + ), + ] + for doc_id, text in context_data: + self.agent_memory.add_vector(doc_id, text) + + def _build_prompt( + self, + messages: list, + base64_image: Optional[Union[str, List[str]]] = None, + dimensions: Optional[Tuple[int, int]] = None, + override_token_limit: bool = False, + condensed_results: str = "", + ) -> list: + """Builds a prompt message specifically for the Cerebras API. + + Args: + messages (list): Existing messages list to build upon. + base64_image (Union[str, List[str]]): Optional Base64-encoded image(s). + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + condensed_results (str): The condensed RAG context. + + Returns: + list: Messages formatted for the Cerebras API. + """ + + # Add system message if provided and not already in history + if self.system_query and (not messages or messages[0].get("role") != "system"): + messages.insert(0, {"role": "system", "content": self.system_query}) + logger.info("Added system message to conversation") + + # Append user query while handling RAG + if condensed_results: + user_message = {"role": "user", "content": f"{condensed_results}\n\n{self.query}"} + logger.info("Created user message with RAG context") + else: + user_message = {"role": "user", "content": self.query} + + messages.append(user_message) + + if base64_image is not None: + # Handle both single image (str) and multiple images (List[str]) + images = [base64_image] if isinstance(base64_image, str) else base64_image + + # For Cerebras, we'll add images inline with text (OpenAI-style format) + for img in images: + img_content = [ + {"type": "text", "text": "Here is an image to analyze:"}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{img}", + "detail": self.image_detail, + }, + }, + ] + messages.append({"role": "user", "content": img_content}) + + logger.info(f"Added {len(images)} image(s) to conversation") + + return messages + + def _send_query(self, messages: list) -> Any: + """Sends the query to the Cerebras API using the official Cerebras SDK. + + Args: + messages (list): The prompt messages to send. + + Returns: + The response message from Cerebras. + + Raises: + Exception: If no response message is returned. + ConnectionError: If there's an issue connecting to the API. + ValueError: If the messages or other parameters are invalid. 
+ """ + try: + # Prepare API call parameters + api_params = { + "model": self.model_name, + "messages": messages, + "max_tokens": self.max_output_tokens_per_request, + } + + # Add tools if available + if self.skill_library and self.skill_library.get_tools(): + tools = self.skill_library.get_tools() + api_params["tools"] = tools # No conversion needed + api_params["tool_choice"] = "auto" + + # Add response format for structured output if specified + if self.response_model is not None: + # Convert Pydantic model to JSON schema for Cerebras + from pydantic import TypeAdapter + + schema = TypeAdapter(self.response_model).json_schema() + + # Ensure additionalProperties is set to False for strict mode + if "additionalProperties" not in schema: + schema["additionalProperties"] = False + + api_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": self.response_model.__name__ + if hasattr(self.response_model, "__name__") + else "response", + "strict": True, + "schema": schema, + }, + } + + # Make the API call + response = self.client.chat.completions.create(**api_params) + + response_message = response.choices[0].message + if response_message is None: + logger.error("Response message does not exist.") + raise Exception("Response message does not exist.") + + return response_message + + except ConnectionError as ce: + logger.error(f"Connection error with Cerebras API: {ce}") + raise + except ValueError as ve: + logger.error(f"Invalid parameters for Cerebras API: {ve}") + raise + except Exception as e: + logger.error(f"Unexpected error in Cerebras API call: {e}") + raise + + def _observable_query( + self, + observer: Observer, + base64_image: Optional[str] = None, + dimensions: Optional[Tuple[int, int]] = None, + override_token_limit: bool = False, + incoming_query: Optional[str] = None, + reset_conversation: bool = False, + ): + """Main query handler that manages conversation history and Cerebras interactions. + + This method follows ClaudeAgent's pattern for efficient conversation history management. + + Args: + observer (Observer): The observer to emit responses to. + base64_image (str): Optional Base64-encoded image. + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + incoming_query (str): Optional query to update the agent's query. + reset_conversation (bool): Whether to reset the conversation history. 
+ """ + try: + # Reset conversation history if requested + if reset_conversation: + self.conversation_history = [] + logger.info("Conversation history reset") + + # Create a local copy of conversation history and record its length + messages = copy.deepcopy(self.conversation_history) + base_len = len(messages) + + # Update query and get context + self._update_query(incoming_query) + _, condensed_results = self._get_rag_context() + + # Build prompt + messages = self._build_prompt( + messages, base64_image, dimensions, override_token_limit, condensed_results + ) + + # Send query and get response + logger.info("Sending Query.") + response_message = self._send_query(messages) + logger.info(f"Received Response: {response_message}") + + if response_message is None: + logger.error("Received None response from Cerebras API") + observer.on_next("") + observer.on_completed() + return + + # Add assistant response to local messages (always) + assistant_message = {"role": "assistant"} + + if response_message.content: + assistant_message["content"] = response_message.content + else: + assistant_message["content"] = "" # Ensure content is never None + + if hasattr(response_message, "tool_calls") and response_message.tool_calls: + assistant_message["tool_calls"] = [] + for tool_call in response_message.tool_calls: + assistant_message["tool_calls"].append( + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + ) + logger.info( + f"Assistant response includes {len(response_message.tool_calls)} tool call(s)" + ) + + messages.append(assistant_message) + + # Handle tool calls if present (add tool messages to conversation) + self._handle_tooling(response_message, messages) + + # At the end, append only new messages to the global conversation history under a lock + if not hasattr(self, "_history_lock"): + self._history_lock = threading.Lock() + with self._history_lock: + for msg in messages[base_len:]: + self.conversation_history.append(msg) + logger.info( + f"Updated conversation history (total: {len(self.conversation_history)} messages)" + ) + + # Send response to observers + result = response_message.content or "" + observer.on_next(result) + self.response_subject.on_next(result) + observer.on_completed() + + except Exception as e: + logger.error(f"Query failed in {self.dev_name}: {e}") + observer.on_error(e) + self.response_subject.on_error(e) + + def _handle_tooling(self, response_message, messages): + """Executes tools and appends tool-use/result blocks to messages.""" + if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: + logger.info("No tool calls found in response message") + return None + + if len(response_message.tool_calls) > 1: + logger.warning( + "Multiple tool calls detected in response message. Not a tested feature." 
+ ) + + # Execute all tools and add their results to messages + for tool_call in response_message.tool_calls: + logger.info(f"Processing tool call: {tool_call.function.name}") + + # Execute the tool + args = json.loads(tool_call.function.arguments) + tool_result = self.skill_library.call(tool_call.function.name, **args) + logger.info(f"Function Call Results: {tool_result}") + + # Add tool result to conversation history (OpenAI format) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": str(tool_result), + "name": tool_call.function.name, + } + ) diff --git a/dimos/robot/unitree/unitree_go2.py b/dimos/robot/unitree/unitree_go2.py index 58447ba9e0..7f1d760b34 100644 --- a/dimos/robot/unitree/unitree_go2.py +++ b/dimos/robot/unitree/unitree_go2.py @@ -21,9 +21,7 @@ from dimos.stream.video_providers.unitree import UnitreeVideoProvider from reactivex.disposable import CompositeDisposable import logging -from dimos.robot.unitree.external.go2_webrtc_connect.go2_webrtc_driver.webrtc_driver import ( - WebRTCConnectionMethod, -) +from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod import os from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from reactivex.scheduler import ThreadPoolScheduler diff --git a/dimos/robot/unitree/unitree_skills.py b/dimos/robot/unitree/unitree_skills.py index 197d7a14fd..38adc399c8 100644 --- a/dimos/robot/unitree/unitree_skills.py +++ b/dimos/robot/unitree/unitree_skills.py @@ -261,122 +261,29 @@ def __call__(self): # region Class-based Skills class Move(AbstractRobotSkill): - """Move the robot using direct velocity commands. - - This skill works with both ROS and WebRTC robot implementations. - """ + """Move the robot using direct velocity commands. Determine the duration required based on the user's distance instructions.""" x: float = Field(..., description="Forward velocity (m/s).") y: float = Field(default=0.0, description="Left/right velocity (m/s)") yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field( - default=0.0, description="How long to move (seconds). 
If 0, command is continuous" - ) + duration: float = Field(default=0.0, description="How long to move (seconds).") def __call__(self): super().__call__() - - from dimos.types.vector import Vector - - vector = Vector(self.x, self.y, self.yaw) - - # Handle duration for continuous movement - if self.duration > 0: - import time - import threading - import asyncio - - # Create a stop event - stop_event = threading.Event() - - # Function to continuously send movement commands - async def continuous_move(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - start_time = time.time() - try: - while ( - not stop_event.is_set() and (time.time() - start_time) < self.duration - ): - self._robot.move(vector) - await asyncio.sleep(0.001) # Send commands at 1000Hz - # Always stop at the end - self._robot.move(Vector(0, 0, 0)) - finally: - loop.close() - - # Run movement in a separate thread with asyncio event loop - move_thread = threading.Thread(target=lambda: asyncio.run(continuous_move())) - move_thread.daemon = True - move_thread.start() - - # Wait for the full duration - time.sleep(self.duration) - stop_event.set() - move_thread.join(timeout=0.5) # Wait for thread to finish with timeout - else: - # Just execute the move command once for continuous movement - self._robot.move(vector) - return True + return self._robot.move_vel(x=self.x, y=self.y, yaw=self.yaw, duration=self.duration) class Reverse(AbstractRobotSkill): - """Reverse the robot using direct velocity commands. - - This skill works with both ROS and WebRTC robot implementations. - """ + """Reverse the robot using direct velocity commands. Determine the duration required based on the user's distance instructions.""" x: float = Field(..., description="Backward velocity (m/s). Positive values move backward.") y: float = Field(default=0.0, description="Left/right velocity (m/s)") yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field( - default=0.0, description="How long to move (seconds). 
If 0, command is continuous" - ) + duration: float = Field(default=0.0, description="How long to move (seconds).") def __call__(self): super().__call__() - from dimos.types.vector import Vector - - # Use negative x for backward movement - vector = Vector(-self.x, self.y, self.yaw) - - # Handle duration for continuous movement - if self.duration > 0: - import time - import threading - import asyncio - - # Create a stop event - stop_event = threading.Event() - - # Function to continuously send movement commands - async def continuous_move(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - start_time = time.time() - try: - while ( - not stop_event.is_set() and (time.time() - start_time) < self.duration - ): - self._robot.move(vector) - await asyncio.sleep(0.001) # Send commands at 1000Hz - # Always stop at the end - self._robot.move(Vector(0, 0, 0)) - finally: - loop.close() - - # Run movement in a separate thread with asyncio event loop - move_thread = threading.Thread(target=lambda: asyncio.run(continuous_move())) - move_thread.daemon = True - move_thread.start() - - # Wait for the full duration - time.sleep(self.duration) - stop_event.set() - move_thread.join(timeout=0.5) # Wait for thread to finish with timeout - else: - # Just execute the move command once for continuous movement - self._robot.move(vector) - return True + # Use move_vel with negative x for backward movement + return self._robot.move_vel(x=-self.x, y=self.y, yaw=self.yaw, duration=self.duration) class SpinLeft(AbstractRobotSkill): """Spin the robot left using degree commands.""" diff --git a/dimos/stream/video_providers/unitree.py b/dimos/stream/video_providers/unitree.py index e91351a229..e1a7587146 100644 --- a/dimos/stream/video_providers/unitree.py +++ b/dimos/stream/video_providers/unitree.py @@ -15,12 +15,7 @@ from dimos.stream.video_provider import AbstractVideoProvider from queue import Queue -from dimos.robot.unitree.external.go2_webrtc_connect.go2_webrtc_driver.constants import ( - WebRTCConnectionMethod, -) -from dimos.robot.unitree.external.go2_webrtc_connect.go2_webrtc_driver.webrtc_driver import ( - Go2WebRTCConnection, -) +from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod from aiortc import MediaStreamTrack import asyncio from reactivex import Observable, create, operators as ops diff --git a/dimos/types/vector.py b/dimos/types/vector.py index eb43c04945..8d3ae1ef91 100644 --- a/dimos/types/vector.py +++ b/dimos/types/vector.py @@ -91,8 +91,6 @@ def __str__(self) -> str: def getArrow(): repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"] - print("SELF X", self.x) - print("SELF Y", self.y) if self.x == 0 and self.y == 0: return "·" @@ -109,6 +107,10 @@ def serialize(self) -> Tuple: """Serialize the vector to a tuple.""" return {"type": "vector", "c": self._data.tolist()} + def __len__(self) -> int: + """Return the dimension of the vector.""" + return len(self._data) + def __eq__(self, other) -> bool: """Check if two vectors are equal using numpy's allclose for floating point comparison.""" if not isinstance(other, Vector): diff --git a/requirements.txt b/requirements.txt index bf2c804ba6..79b6393265 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ opencv-python python-dotenv openai anthropic>=0.19.0 +cerebras-cloud-sdk numpy>=1.26.4,<2.0.0 colorlog==6.9.0 yapf==0.40.2 diff --git a/tests/test_unitree_ros_v0.0.4.py b/tests/test_unitree_ros_v0.0.4.py new file mode 100644 index 0000000000..79f47dfef0 --- /dev/null +++ 
b/tests/test_unitree_ros_v0.0.4.py @@ -0,0 +1,202 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os + +import time +from dotenv import load_dotenv +from dimos.agents.claude_agent import ClaudeAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal +from dimos.skills.visual_navigation_skills import FollowHuman +import reactivex as rx +import reactivex.operators as ops +from dimos.stream.audio.pipelines import tts, stt +import threading +import json +from dimos.types.vector import Vector +from dimos.skills.speak import Speak +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.utils.reactive import backpressure + +# Load API key from environment +load_dotenv() + +# Allow command line arguments to control spatial memory parameters +import argparse + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Run the robot with optional spatial memory parameters" + ) + parser.add_argument( + "--spatial-memory-dir", type=str, help="Directory for storing spatial memory data" + ) + parser.add_argument( + "--voice", + action="store_true", + help="Use voice input from microphone instead of web interface", + ) + return parser.parse_args() + + +args = parse_arguments() + +# Initialize robot with spatial memory parameters +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + skills=MyUnitreeSkills(), + mock_connection=False, + spatial_memory_dir=args.spatial_memory_dir, # Will use default if None + new_memory=True, +) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) +local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) + +# Initialize object detection stream +min_confidence = 0.6 +class_filter = None # No class filtering +detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) + +# Create video stream from robot's camera +video_stream = backpressure(robot.get_ros_video_stream()) + +# Initialize ObjectDetectionStream with robot +object_detector = ObjectDetectionStream( + camera_intrinsics=robot.camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + transform_to_map=robot.ros_control.transform_pose, + detector=detector, + video_stream=video_stream, +) + +# Create visualization stream for web interface +viz_stream = backpressure(object_detector.get_stream()).pipe( + ops.share(), + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not 
None), +) + +# Get the formatted detection stream +formatted_detection_stream = object_detector.get_formatted_stream().pipe( + ops.filter(lambda x: x is not None) +) + + +# Create a direct mapping that combines detection data with locations +def combine_with_locations(object_detections): + # Get locations from spatial memory + try: + locations = robot.get_spatial_memory().get_robot_locations() + + # Format the locations section + locations_text = "\n\nSaved Robot Locations:\n" + if locations: + for loc in locations: + locations_text += f"- {loc.name}: Position ({loc.position[0]:.2f}, {loc.position[1]:.2f}, {loc.position[2]:.2f}), " + locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n" + else: + locations_text += "None\n" + + # Simply concatenate the strings + return object_detections + locations_text + except Exception as e: + print(f"Error adding locations: {e}") + return object_detections + + +# Create the combined stream with a simple pipe operation +enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) + +streams = { + "unitree_video": robot.get_ros_video_stream(), + "local_planner_viz": local_planner_viz_stream, + "object_detection": viz_stream, +} +text_streams = { + "agent_responses": agent_response_stream, +} + +web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) + +stt_node = stt() + +# Read system query from prompt.txt file +with open( + os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt"), "r" +) as f: + system_query = f.read() + +# Create a ClaudeAgent instance with either voice input or web interface input based on flag +input_stream = stt_node.emit_text() if args.voice else web_interface.query_stream +print(f"Using {'voice input' if args.voice else 'web interface input'} for queries") + +agent = ClaudeAgent( + dev_name="test_agent", + input_query_stream=input_stream, + input_data_stream=enhanced_data_stream, # Add the enhanced data stream + skills=robot.get_skills(), + system_query=system_query, + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=0, +) + +# Initialize TTS node only if voice flag is set +tts_node = None +if args.voice: + print("Voice mode: Enabling TTS for speech output") + tts_node = tts() + tts_node.consume_text(agent.get_response_observable()) +else: + print("Web interface mode: Disabling TTS to avoid audio issues") + +robot_skills = robot.get_skills() +robot_skills.add(ObserveStream) +robot_skills.add(KillSkill) +robot_skills.add(NavigateWithText) +robot_skills.add(FollowHuman) +robot_skills.add(GetPose) +# Add Speak skill only if voice flag is set +if args.voice: + robot_skills.add(Speak) +# robot_skills.add(NavigateToGoal) +robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +robot_skills.create_instance("NavigateWithText", robot=robot) +robot_skills.create_instance("FollowHuman", robot=robot) +robot_skills.create_instance("GetPose", robot=robot) +# robot_skills.create_instance("NavigateToGoal", robot=robot) +# Create Speak skill instance only if voice flag is set +if args.voice: + robot_skills.create_instance("Speak", tts_node=tts_node) + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +print("ObserveStream and Kill skills registered and ready for use") +print("Created 
web interface; starting server") + +web_interface.run()
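
For reviewers who want to exercise the new agent, a minimal usage sketch follows. It assumes CEREBRAS_API_KEY is exported (the Cerebras() client reads it from the environment) and that get_response_observable() is inherited from LLMAgent, as it is for ClaudeAgent in the test script above; the Subject-based query stream is purely illustrative.

import reactivex as rx
from dimos.agents.cerebras_agent import CerebrasAgent

# Illustrative query source; any Observable of strings works as input_query_stream.
query_subject = rx.subject.Subject()

agent = CerebrasAgent(
    dev_name="cerebras_demo",
    input_query_stream=query_subject,  # text-only input, so process_all_inputs defaults to True
    system_query="You are a concise robotics assistant.",
    model_name="llama-4-scout-17b-16e-instruct",
)

# Print each response as it arrives (assumed inherited from LLMAgent).
agent.get_response_observable().subscribe(lambda text: print("agent:", text))

query_subject.on_next("Summarize what optical flow is in one sentence.")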
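
A second sketch covers the structured-output path: _send_query converts response_model into a strict json_schema response_format via pydantic's TypeAdapter. The SceneReport model below is illustrative, not part of the PR.

from pydantic import BaseModel
from dimos.agents.cerebras_agent import CerebrasAgent

class SceneReport(BaseModel):
    summary: str
    object_count: int

agent = CerebrasAgent(
    dev_name="structured_demo",
    query="Describe the scene.",
    response_model=SceneReport,
)
# _send_query sends response_format={"type": "json_schema", "json_schema":
# {"name": "SceneReport", "strict": True, "schema": ...}}, so a reply can be
# parsed with SceneReport.model_validate_json(response_text).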
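
Finally, a quick check of the new Vector.__len__, assuming the three-argument Vector(x, y, yaw) construction seen in the code this PR removes from the Move skill.

from dimos.types.vector import Vector

v = Vector(1.0, 0.0, 0.5)
assert len(v) == 3  # __len__ returns the dimension of the underlying data array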