# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for manipulation history tracking and search."""

from typing import Dict, List, Optional, Any, Tuple, Union, Set, Callable
from dataclasses import dataclass, field
import time
from datetime import datetime
import os
import json
import pickle
import uuid

from dimos.types.manipulation import (
    ManipulationTask,
    AbstractConstraint,
    ManipulationTaskConstraint,
    ManipulationMetadata,
)
from dimos.utils.logging_config import setup_logger

# Logger name matches this module's actual package path (dimos.manipulation,
# not the stale dimos.types it was copied from).
logger = setup_logger("dimos.manipulation.manipulation_history")


@dataclass
class ManipulationHistoryEntry:
    """An entry in the manipulation history.

    Attributes:
        task: The manipulation task executed
        timestamp: When the manipulation was performed (UNIX time)
        result: Result of the manipulation (e.g. {"status": "success", ...})
        manipulation_response: Free-form response from the motion planner /
            manipulation executor, if any
    """

    task: ManipulationTask
    timestamp: float = field(default_factory=time.time)
    result: Dict[str, Any] = field(default_factory=dict)
    manipulation_response: Optional[str] = None

    def __str__(self) -> str:
        status = self.result.get("status", "unknown")
        return (
            f"ManipulationHistoryEntry(task='{self.task.description}', "
            f"status={status}, "
            f"time={datetime.fromtimestamp(self.timestamp).strftime('%H:%M:%S')})"
        )


class ManipulationHistory:
    """A simplified, list-based storage for manipulation history.

    This class provides an efficient way to store and query manipulation tasks,
    focusing on quick lookups and flexible search capabilities. Entries are
    kept in chronological order and optionally persisted to disk (pickle for
    round-tripping, JSON for human inspection).
    """

    def __init__(self, output_dir: Optional[str] = None, new_memory: bool = False):
        """Initialize a new manipulation history.

        Args:
            output_dir: Directory to save history to. If None, the history is
                kept in memory only and never persisted.
            new_memory: If True, creates a new memory instead of loading an
                existing one from ``output_dir``.
        """
        self._history: List[ManipulationHistoryEntry] = []
        self._output_dir = output_dir

        if output_dir and not new_memory:
            self.load_from_dir(output_dir)
        elif output_dir:
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"Created new manipulation history at {output_dir}")

    def __len__(self) -> int:
        """Return the number of entries in the history."""
        return len(self._history)

    def __str__(self) -> str:
        """Return a string representation of the history."""
        if not self._history:
            return "ManipulationHistory(empty)"

        return (
            f"ManipulationHistory(entries={len(self._history)}, "
            f"time_range={datetime.fromtimestamp(self._history[0].timestamp).strftime('%Y-%m-%d %H:%M:%S')} to "
            f"{datetime.fromtimestamp(self._history[-1].timestamp).strftime('%Y-%m-%d %H:%M:%S')})"
        )

    def clear(self) -> None:
        """Clear all entries from the history (and persist the empty state)."""
        self._history.clear()
        logger.info("Cleared manipulation history")

        if self._output_dir:
            self.save_history()

    def add_entry(self, entry: ManipulationHistoryEntry) -> None:
        """Add an entry to the history, keeping chronological order.

        Args:
            entry: The entry to add
        """
        self._history.append(entry)
        # Timsort is near-linear on an already-sorted list, so appending then
        # sorting keeps insertion cheap for the common in-order case.
        self._history.sort(key=lambda e: e.timestamp)

        if self._output_dir:
            self.save_history()

    def save_history(self) -> None:
        """Save the history to the output directory (pickle + JSON mirror)."""
        if not self._output_dir:
            logger.warning("Cannot save history: no output directory specified")
            return

        os.makedirs(self._output_dir, exist_ok=True)
        history_path = os.path.join(self._output_dir, "manipulation_history.pickle")

        with open(history_path, "wb") as f:
            pickle.dump(self._history, f)

        logger.info(f"Saved manipulation history to {history_path}")

        # Also save a JSON representation for easier inspection
        json_path = os.path.join(self._output_dir, "manipulation_history.json")
        try:
            history_data = [
                {
                    "task": {
                        "description": entry.task.description,
                        "target_object": entry.task.target_object,
                        "target_point": entry.task.target_point,
                        "timestamp": entry.task.timestamp,
                        "task_id": entry.task.task_id,
                        "metadata": entry.task.metadata,
                    },
                    "result": entry.result,
                    "timestamp": entry.timestamp,
                    "manipulation_response": entry.manipulation_response,
                }
                for entry in self._history
            ]

            with open(json_path, "w") as f:
                json.dump(history_data, f, indent=2)

            logger.info(f"Saved JSON representation to {json_path}")
        except Exception as e:
            # JSON mirror is best-effort; the pickle above is the source of truth.
            logger.error(f"Failed to save JSON representation: {e}")

    def load_from_dir(self, directory: str) -> None:
        """Load history from the specified directory.

        Args:
            directory: Directory to load history from
        """
        history_path = os.path.join(directory, "manipulation_history.pickle")

        if not os.path.exists(history_path):
            logger.warning(f"No history found at {history_path}")
            return

        try:
            # NOTE: pickle is only safe on files this process wrote itself;
            # never point output_dir at untrusted data.
            with open(history_path, "rb") as f:
                self._history = pickle.load(f)

            logger.info(
                f"Loaded manipulation history from {history_path} with {len(self._history)} entries"
            )
        except Exception as e:
            logger.error(f"Failed to load history: {e}")

    def get_all_entries(self) -> List[ManipulationHistoryEntry]:
        """Get all entries in chronological order.

        Returns:
            List of all manipulation history entries (a shallow copy)
        """
        return self._history.copy()

    def get_entry_by_index(self, index: int) -> Optional[ManipulationHistoryEntry]:
        """Get an entry by its index.

        Args:
            index: Index of the entry to retrieve

        Returns:
            The entry at the specified index or None if index is out of bounds
        """
        if 0 <= index < len(self._history):
            return self._history[index]
        return None

    def get_entries_by_timerange(
        self, start_time: float, end_time: float
    ) -> List[ManipulationHistoryEntry]:
        """Get entries within a specific time range (inclusive on both ends).

        Args:
            start_time: Start time (UNIX timestamp)
            end_time: End time (UNIX timestamp)

        Returns:
            List of entries within the specified time range
        """
        return [entry for entry in self._history if start_time <= entry.timestamp <= end_time]

    def get_entries_by_object(self, object_name: str) -> List[ManipulationHistoryEntry]:
        """Get entries related to a specific object (exact target_object match).

        Args:
            object_name: Name of the object to search for

        Returns:
            List of entries related to the specified object
        """
        return [entry for entry in self._history if entry.task.target_object == object_name]

    def create_task_entry(
        self,
        task: ManipulationTask,
        result: Optional[Dict[str, Any]] = None,
        agent_response: Optional[str] = None,
    ) -> ManipulationHistoryEntry:
        """Create a new manipulation history entry and add it to the history.

        Args:
            task: The manipulation task
            result: Result of the manipulation
            agent_response: Response from the agent about this manipulation

        Returns:
            The created history entry
        """
        entry = ManipulationHistoryEntry(
            task=task, result=result or {}, manipulation_response=agent_response
        )
        self.add_entry(entry)
        return entry

    def search(self, **kwargs) -> List[ManipulationHistoryEntry]:
        """Flexible search over any ManipulationHistoryEntry field using dot notation.

        This method supports dot notation to access nested fields. String values
        automatically use substring matching (contains), while all other types
        use exact matching. All criteria must match (logical AND).

        Examples:
            # Time-based searches:
            - search(**{"task.metadata.timestamp": ('>', start_time)}) - entries after start_time
            - search(**{"task.metadata.timestamp": ('>=', time - 1800)}) - entries in last 30 mins

            # Constraint searches:
            - search(**{"task.constraints.*.reference_point.x": 2.5}) - tasks with x=2.5 reference point
            - search(**{"task.constraints.*.end_angle.x": 90}) - tasks with 90-degree x rotation
            - search(**{"task.constraints.*.lock_x": True}) - tasks with x-axis translation locked

            # Object and result searches:
            - search(**{"task.metadata.objects.*.label": "cup"}) - tasks involving cups
            - search(**{"result.status": "success"}) - successful tasks
            - search(**{"result.error": "Collision"}) - tasks that had collisions

        Args:
            **kwargs: Key-value pairs for searching using dot notation for field paths.

        Returns:
            List of matching entries (all entries when no criteria are given)
        """
        if not kwargs:
            return self._history.copy()

        results = self._history.copy()

        for key, value in kwargs.items():
            # Each criterion narrows the candidate set; string values use
            # contains matching automatically (see _check_field_match).
            results = [e for e in results if self._check_field_match(e, key, value)]

        return results

    def _check_field_match(self, entry, field_path, value) -> bool:
        """Check if a field matches the value.

        For string values, we automatically use substring matching (contains).
        For collections (returned by a ``*`` path), we check if any element matches.
        For numeric values (like timestamps), supports >, <, >= and <= comparisons
        expressed as 2-tuples, e.g. ``('>', timestamp)``.
        For all other types, we use exact matching.

        Args:
            entry: The entry to check
            field_path: Dot-separated path to the field
            value: Value to match against. For comparisons, use tuples like:
                ('>', timestamp) - greater than
                ('<', timestamp) - less than
                ('>=', timestamp) - greater or equal
                ('<=', timestamp) - less or equal

        Returns:
            True if the field matches the value, False otherwise. A missing
            field or an incomparable value counts as a non-match, never an error.
        """
        try:
            field_value = self._get_value_by_path(entry, field_path)

            # Handle comparison operators for timestamps and numbers.
            # A 2-tuple whose first element is not a recognized operator falls
            # through to exact matching (so tuples can still be matched literally).
            if isinstance(value, tuple) and len(value) == 2:
                op, compare_value = value
                if op == ">":
                    return field_value > compare_value
                elif op == "<":
                    return field_value < compare_value
                elif op == ">=":
                    return field_value >= compare_value
                elif op == "<=":
                    return field_value <= compare_value

            # Handle lists (from collection searches)
            if isinstance(field_value, list):
                for item in field_value:
                    # String values use contains matching
                    if isinstance(item, str) and isinstance(value, str):
                        if value in item:
                            return True
                    # All other types use exact matching
                    elif item == value:
                        return True
                return False

            # String values use contains matching
            elif isinstance(field_value, str) and isinstance(value, str):
                return value in field_value
            # All other types use exact matching
            else:
                return field_value == value

        except (AttributeError, KeyError, TypeError):
            # AttributeError/KeyError: path does not exist on this entry.
            # TypeError: ordered comparison against an incomparable field value
            # (e.g. None > 5) — treat as "does not match" rather than crashing
            # the whole search.
            return False

    def _get_value_by_path(self, obj, path):
        """Get a value from an object using a dot-separated path.

        This method handles three special cases:
        1. Regular attribute access (obj.attr)
        2. Dictionary key access (dict[key])
        3. Collection search (dict.*.attr) - when * is used, it searches all
           values in the collection and returns the list of matches

        Args:
            obj: Object to get value from
            path: Dot-separated path to the field (e.g., "task.metadata.robot")

        Returns:
            Value at the specified path, or a (flattened) list of values for
            collection searches

        Raises:
            AttributeError: If an attribute in the path doesn't exist, or a
                wildcard is applied to a non-collection
            KeyError: If a dictionary key in the path doesn't exist
        """
        current = obj
        parts = path.split(".")

        for i, part in enumerate(parts):
            # Collection search (*.attr) - search across all items in a collection
            if part == "*":
                # Get remaining path parts
                remaining_path = ".".join(parts[i + 1 :])

                # Handle different collection types
                if isinstance(current, dict):
                    items = current.values()
                    if not remaining_path:  # If * is the last part, return all values
                        return list(items)
                elif isinstance(current, list):
                    items = current
                    if not remaining_path:  # If * is the last part, return all items
                        return items
                else:  # Not a collection
                    raise AttributeError(
                        f"Cannot use wildcard on non-collection type: {type(current)}"
                    )

                # Apply remaining path to each item in the collection
                results = []
                for item in items:
                    try:
                        # Recursively get values from each item
                        value = self._get_value_by_path(item, remaining_path)
                        if isinstance(value, list):  # Flatten nested lists
                            results.extend(value)
                        else:
                            results.append(value)
                    except (AttributeError, KeyError):
                        # Skip items that don't have the attribute
                        pass
                return results

            # Regular attribute/key access
            elif isinstance(current, dict):
                current = current[part]
            else:
                current = getattr(current, part)

        return current
+""" + +from typing import Dict, List, Optional, Any, Tuple, Union +from dataclasses import dataclass +import os +import time +from datetime import datetime +from reactivex.disposable import Disposable +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.types.manipulation import ( + AbstractConstraint, + TranslationConstraint, + RotationConstraint, + ForceConstraint, + ManipulationTaskConstraint, + ManipulationTask, + ManipulationMetadata, + ObjectData, +) +from dimos.manipulation.manipulation_history import ( + ManipulationHistory, + ManipulationHistoryEntry, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.manipulation_interface") + + +class ManipulationInterface: + """ + Interface for accessing and managing robot manipulation data. + + This class provides a unified interface for managing manipulation tasks and constraints. + It maintains a list of constraints generated by the Agent and provides methods to + add and manage manipulation tasks. + """ + + def __init__( + self, + output_dir: str, + new_memory: bool = False, + perception_stream: ObjectDetectionStream = None, + ): + """ + Initialize a new ManipulationInterface instance. 
+ + Args: + output_dir: Directory for storing manipulation data + new_memory: If True, creates a new manipulation history from scratch + perception_stream: ObjectDetectionStream instance for real-time object data + """ + self.output_dir = output_dir + + # Create manipulation history directory + manipulation_dir = os.path.join(output_dir, "manipulation_history") + os.makedirs(manipulation_dir, exist_ok=True) + + # Initialize manipulation history + self.manipulation_history: ManipulationHistory = ManipulationHistory( + output_dir=manipulation_dir, new_memory=new_memory + ) + + # List of constraints generated by the Agent via constraint generation skills + self.agent_constraints: List[AbstractConstraint] = [] + + # Initialize object detection stream and related properties + self.perception_stream = perception_stream + self.latest_objects: List[ObjectData] = [] + self.stream_subscription: Optional[Disposable] = None + + # Set up subscription to perception stream if available + self._setup_perception_subscription() + + logger.info("ManipulationInterface initialized") + + def add_constraint(self, constraint: AbstractConstraint) -> None: + """ + Add a constraint generated by the Agent via a constraint generation skill. + + Args: + constraint: The constraint to add to agent_constraints + """ + self.agent_constraints.append(constraint) + logger.info(f"Added agent constraint: {constraint}") + + def get_constraints(self) -> List[AbstractConstraint]: + """ + Get all constraints generated by the Agent via constraint generation skills. + + Returns: + List of all constraints created by the Agent + """ + return self.agent_constraints + + def get_constraint(self, constraint_id: str) -> Optional[AbstractConstraint]: + """ + Get a specific constraint by its ID. 
+ + Args: + constraint_id: ID of the constraint to retrieve + + Returns: + The matching constraint or None if not found + """ + # Find constraint with matching ID + for constraint in self.agent_constraints: + if constraint.id == constraint_id: + return constraint + + logger.warning(f"Constraint with ID {constraint_id} not found") + return None + + def add_manipulation_task( + self, task: ManipulationTask, manipulation_response: Optional[str] = None + ) -> None: + """ + Add a manipulation task to ManipulationHistory. + + Args: + task: The ManipulationTask to add + manipulation_response: Optional response from the motion planner/executor + + """ + # Add task to history + self.manipulation_history.add_entry( + task=task, result=None, notes=None, manipulation_response=manipulation_response + ) + + def get_manipulation_task(self, task_id: str) -> Optional[ManipulationTask]: + """ + Get a manipulation task by its ID. + + Args: + task_id: ID of the task to retrieve + + Returns: + The task object or None if not found + """ + return self.history.get_manipulation_task(task_id) + + def get_all_manipulation_tasks(self) -> List[ManipulationTask]: + """ + Get all manipulation tasks. + + Returns: + List of all manipulation tasks + """ + return self.history.get_all_manipulation_tasks() + + def update_task_status( + self, task_id: str, status: str, result: Optional[Dict[str, Any]] = None + ) -> Optional[ManipulationTask]: + """ + Update the status and result of a manipulation task. + + Args: + task_id: ID of the task to update + status: New status for the task (e.g., 'completed', 'failed') + result: Optional dictionary with result data + + Returns: + The updated task or None if task not found + """ + return self.history.update_task_status(task_id, status, result) + + # === Perception stream methods === + + def _setup_perception_subscription(self): + """ + Set up subscription to perception stream if available. 
+ """ + if self.perception_stream: + # Subscribe to the stream and update latest_objects + self.stream_subscription = self.perception_stream.get_stream().subscribe( + on_next=self._update_latest_objects, + on_error=lambda e: logger.error(f"Error in perception stream: {e}"), + ) + logger.info("Subscribed to perception stream") + + def _update_latest_objects(self, data): + """ + Update the latest detected objects. + + Args: + data: Data from the object detection stream + """ + if "objects" in data: + self.latest_objects = data["objects"] + + def get_latest_objects(self) -> List[ObjectData]: + """ + Get the latest detected objects from the stream. + + Returns: + List of the most recently detected objects + """ + return self.latest_objects + + def get_object_by_id(self, object_id: int) -> Optional[ObjectData]: + """ + Get a specific object by its tracking ID. + + Args: + object_id: Tracking ID of the object + + Returns: + The object data or None if not found + """ + for obj in self.latest_objects: + if obj["object_id"] == object_id: + return obj + return None + + def get_objects_by_label(self, label: str) -> List[ObjectData]: + """ + Get all objects with a specific label. + + Args: + label: Class label to filter objects by + + Returns: + List of objects matching the label + """ + return [obj for obj in self.latest_objects if obj["label"] == label] + + def set_perception_stream(self, perception_stream): + """ + Set or update the perception stream. + + Args: + perception_stream: The PerceptionStream instance + """ + # Clean up existing subscription if any + self.cleanup_perception_subscription() + + # Set new stream and subscribe + self.perception_stream = perception_stream + self._setup_perception_subscription() + + def cleanup_perception_subscription(self): + """ + Clean up the stream subscription. 
+ """ + if self.stream_subscription: + self.stream_subscription.dispose() + self.stream_subscription = None + + # === Utility methods === + + def clear_history(self) -> None: + """ + Clear all manipulation history data and agent constraints. + """ + self.manipulation_history.clear() + self.agent_constraints.clear() + logger.info("Cleared manipulation history and agent constraints") + + def __str__(self) -> str: + """ + String representation of the manipulation interface. + + Returns: + String representation with key stats + """ + has_stream = self.perception_stream is not None + return f"ManipulationInterface(history={self.manipulation_history}, agent_constraints={len(self.agent_constraints)}, perception_stream={has_stream}, detected_objects={len(self.latest_objects)})" + + def __del__(self): + """ + Clean up resources on deletion. + """ + self.cleanup_perception_subscription() diff --git a/dimos/manipulation/test_manipulation_history.py b/dimos/manipulation/test_manipulation_history.py new file mode 100644 index 0000000000..6057c5d985 --- /dev/null +++ b/dimos/manipulation/test_manipulation_history.py @@ -0,0 +1,447 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import time +import tempfile +import pytest +from typing import Dict, List, Optional, Any, Tuple + +from dimos.manipulation.manipulation_history import ManipulationHistory, ManipulationHistoryEntry +from dimos.types.manipulation import ( + ManipulationTask, + AbstractConstraint, + TranslationConstraint, + RotationConstraint, + ForceConstraint, + ManipulationTaskConstraint, + ManipulationMetadata, +) +from dimos.types.vector import Vector + + +@pytest.fixture +def sample_task(): + """Create a sample manipulation task for testing.""" + return ManipulationTask( + description="Pick up the cup", + target_object="cup", + target_point=(100, 200), + task_id="task1", + metadata={ + "timestamp": time.time(), + "objects": { + "cup1": { + "object_id": 1, + "label": "cup", + "confidence": 0.95, + "position": {"x": 1.5, "y": 2.0, "z": 0.5}, + }, + "table1": { + "object_id": 2, + "label": "table", + "confidence": 0.98, + "position": {"x": 0.0, "y": 0.0, "z": 0.0}, + }, + }, + }, + ) + + +@pytest.fixture +def sample_task_with_constraints(): + """Create a sample manipulation task with constraints for testing.""" + task = ManipulationTask( + description="Rotate the bottle", + target_object="bottle", + target_point=(150, 250), + task_id="task2", + metadata={ + "timestamp": time.time(), + "objects": { + "bottle1": { + "object_id": 3, + "label": "bottle", + "confidence": 0.92, + "position": {"x": 2.5, "y": 1.0, "z": 0.3}, + } + }, + }, + ) + + # Add rich translation constraint + translation_constraint = TranslationConstraint( + translation_axis="y", + reference_point=Vector(2.5, 1.0, 0.3), + bounds_min=Vector(2.0, 0.5, 0.3), + bounds_max=Vector(3.0, 1.5, 0.3), + target_point=Vector(2.7, 1.2, 0.3), + description="Constrained translation along Y-axis only", + ) + task.add_constraint(translation_constraint) + + # Add rich rotation constraint + rotation_constraint = RotationConstraint( + rotation_axis="roll", + start_angle=Vector(0, 0, 0), + end_angle=Vector(90, 0, 0), + 
pivot_point=Vector(2.5, 1.0, 0.3), + secondary_pivot_point=Vector(2.5, 1.0, 0.5), + description="Constrained rotation around X-axis (roll only)", + ) + task.add_constraint(rotation_constraint) + + # Add force constraint + force_constraint = ForceConstraint( + min_force=2.0, + max_force=5.0, + force_direction=Vector(0, 0, -1), + description="Apply moderate downward force during manipulation", + ) + task.add_constraint(force_constraint) + + return task + + +@pytest.fixture +def temp_output_dir(): + """Create a temporary directory for testing history saving/loading.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +@pytest.fixture +def populated_history(sample_task, sample_task_with_constraints): + """Create a populated history with multiple entries for testing.""" + history = ManipulationHistory() + + # Add first entry + entry1 = ManipulationHistoryEntry( + task=sample_task, + result={"status": "success", "execution_time": 2.5}, + manipulation_response="Successfully picked up the cup", + ) + history.add_entry(entry1) + + # Add second entry + entry2 = ManipulationHistoryEntry( + task=sample_task_with_constraints, + result={"status": "failure", "error": "Collision detected"}, + manipulation_response="Failed to rotate the bottle due to collision", + ) + history.add_entry(entry2) + + return history + + +def test_manipulation_history_init(): + """Test initialization of ManipulationHistory.""" + # Default initialization + history = ManipulationHistory() + assert len(history) == 0 + assert str(history) == "ManipulationHistory(empty)" + + # With output directory + with tempfile.TemporaryDirectory() as temp_dir: + history = ManipulationHistory(output_dir=temp_dir, new_memory=True) + assert len(history) == 0 + assert os.path.exists(temp_dir) + + +def test_manipulation_history_add_entry(sample_task): + """Test adding entries to ManipulationHistory.""" + history = ManipulationHistory() + + # Create and add entry + entry = ManipulationHistoryEntry( + 
task=sample_task, result={"status": "success"}, manipulation_response="Task completed" + ) + history.add_entry(entry) + + assert len(history) == 1 + assert history.get_entry_by_index(0) == entry + + +def test_manipulation_history_create_task_entry(sample_task): + """Test creating a task entry directly.""" + history = ManipulationHistory() + + entry = history.create_task_entry( + task=sample_task, result={"status": "success"}, agent_response="Task completed" + ) + + assert len(history) == 1 + assert entry.task == sample_task + assert entry.result["status"] == "success" + assert entry.manipulation_response == "Task completed" + + +def test_manipulation_history_save_load(temp_output_dir, sample_task): + """Test saving and loading history from disk.""" + # Create history and add entry + history = ManipulationHistory(output_dir=temp_output_dir) + entry = history.create_task_entry( + task=sample_task, result={"status": "success"}, agent_response="Task completed" + ) + + # Check that files were created + pickle_path = os.path.join(temp_output_dir, "manipulation_history.pickle") + json_path = os.path.join(temp_output_dir, "manipulation_history.json") + assert os.path.exists(pickle_path) + assert os.path.exists(json_path) + + # Create new history that loads from the saved files + loaded_history = ManipulationHistory(output_dir=temp_output_dir) + assert len(loaded_history) == 1 + assert loaded_history.get_entry_by_index(0).task.description == sample_task.description + + +def test_manipulation_history_clear(populated_history): + """Test clearing the history.""" + assert len(populated_history) > 0 + + populated_history.clear() + assert len(populated_history) == 0 + assert str(populated_history) == "ManipulationHistory(empty)" + + +def test_manipulation_history_get_methods(populated_history): + """Test various getter methods of ManipulationHistory.""" + # get_all_entries + entries = populated_history.get_all_entries() + assert len(entries) == 2 + + # get_entry_by_index + entry 
= populated_history.get_entry_by_index(0) + assert entry.task.task_id == "task1" + + # Out of bounds index + assert populated_history.get_entry_by_index(100) is None + + # get_entries_by_timerange + start_time = time.time() - 3600 # 1 hour ago + end_time = time.time() + 3600 # 1 hour from now + entries = populated_history.get_entries_by_timerange(start_time, end_time) + assert len(entries) == 2 + + # get_entries_by_object + cup_entries = populated_history.get_entries_by_object("cup") + assert len(cup_entries) == 1 + assert cup_entries[0].task.task_id == "task1" + + bottle_entries = populated_history.get_entries_by_object("bottle") + assert len(bottle_entries) == 1 + assert bottle_entries[0].task.task_id == "task2" + + +def test_manipulation_history_search_basic(populated_history): + """Test basic search functionality.""" + # Search by exact match on top-level fields + results = populated_history.search(timestamp=populated_history.get_entry_by_index(0).timestamp) + assert len(results) == 1 + + # Search by task fields + results = populated_history.search(**{"task.task_id": "task1"}) + assert len(results) == 1 + assert results[0].task.target_object == "cup" + + # Search by result fields + results = populated_history.search(**{"result.status": "success"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search by manipulation_response (substring match for strings) + results = populated_history.search(manipulation_response="picked up") + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + +def test_manipulation_history_search_nested(populated_history): + """Test search with nested field paths.""" + # Search by nested metadata fields + results = populated_history.search( + **{ + "task.metadata.timestamp": populated_history.get_entry_by_index(0).task.metadata[ + "timestamp" + ] + } + ) + assert len(results) == 1 + + # Search by nested object fields + results = populated_history.search(**{"task.metadata.objects.cup1.label": 
"cup"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search by position values + results = populated_history.search(**{"task.metadata.objects.cup1.position.x": 1.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + +def test_manipulation_history_search_wildcards(populated_history): + """Test search with wildcard patterns.""" + # Search for any object with label "cup" + results = populated_history.search(**{"task.metadata.objects.*.label": "cup"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search for any object with confidence > 0.95 + results = populated_history.search(**{"task.metadata.objects.*.confidence": 0.98}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search for any object position with x=2.5 + results = populated_history.search(**{"task.metadata.objects.*.position.x": 2.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_constraints(populated_history): + """Test search by constraint properties.""" + # Find entries with any TranslationConstraint with y-axis + results = populated_history.search(**{"task.constraints.*.translation_axis": "y"}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Find entries with any RotationConstraint with roll axis + results = populated_history.search(**{"task.constraints.*.rotation_axis": "roll"}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_string_contains(populated_history): + """Test string contains searching.""" + # Basic string contains + results = populated_history.search(**{"task.description": "Pick"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Nested string contains + results = populated_history.search(manipulation_response="collision") + assert len(results) == 1 + assert results[0].task.task_id == "task2" 
+ + +def test_manipulation_history_search_multiple_criteria(populated_history): + """Test search with multiple criteria.""" + # Multiple criteria - all must match + results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Multiple criteria with no matches + results = populated_history.search(**{"task.target_object": "cup", "result.status": "failure"}) + assert len(results) == 0 + + # Combination of direct and wildcard paths + results = populated_history.search( + **{"task.target_object": "bottle", "task.metadata.objects.*.position.z": 0.3} + ) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_nonexistent_fields(populated_history): + """Test search with fields that don't exist.""" + # Search by nonexistent field + results = populated_history.search(nonexistent_field="value") + assert len(results) == 0 + + # Search by nonexistent nested field + results = populated_history.search(**{"task.nonexistent_field": "value"}) + assert len(results) == 0 + + # Search by nonexistent object + results = populated_history.search(**{"task.metadata.objects.nonexistent_object": "value"}) + assert len(results) == 0 + + +def test_manipulation_history_search_timestamp_ranges(populated_history): + """Test searching by timestamp ranges.""" + # Get reference timestamps + entry1_time = populated_history.get_entry_by_index(0).task.metadata["timestamp"] + entry2_time = populated_history.get_entry_by_index(1).task.metadata["timestamp"] + mid_time = (entry1_time + entry2_time) / 2 + + # Search for timestamps before second entry + results = populated_history.search(**{"task.metadata.timestamp": ("<", entry2_time)}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search for timestamps after first entry + results = populated_history.search(**{"task.metadata.timestamp": (">", entry1_time)}) + 
assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search within a time window using >= and <= + results = populated_history.search(**{"task.metadata.timestamp": (">=", mid_time - 1800)}) + assert len(results) == 2 + assert results[0].task.task_id == "task1" + assert results[1].task.task_id == "task2" + + +def test_manipulation_history_search_vector_fields(populated_history): + """Test searching by vector components in constraints.""" + # Search by reference point components + results = populated_history.search(**{"task.constraints.*.reference_point.x": 2.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search by target point components + results = populated_history.search(**{"task.constraints.*.target_point.z": 0.3}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search by rotation angles + results = populated_history.search(**{"task.constraints.*.end_angle.x": 90}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_execution_details(populated_history): + """Test searching by execution time and error patterns.""" + # Search by execution time + results = populated_history.search(**{"result.execution_time": 2.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search by error message pattern + results = populated_history.search(**{"result.error": "Collision"}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search by status + results = populated_history.search(**{"result.status": "success"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + +def test_manipulation_history_search_multiple_criteria_2(populated_history): + """Test search with multiple criteria.""" + # Multiple criteria - all must match + results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) + assert len(results) == 1 + assert 
results[0].task.task_id == "task1" + + # Multiple criteria with no matches + results = populated_history.search(**{"task.target_object": "cup", "result.status": "failure"}) + assert len(results) == 0 + + # Combination of direct and wildcard paths + results = populated_history.search( + **{"task.target_object": "bottle", "task.metadata.objects.*.position.z": 0.3} + ) + assert len(results) == 1 + assert results[0].task.task_id == "task2" diff --git a/dimos/perception/detection2d/detic_2d_det.py b/dimos/perception/detection2d/detic_2d_det.py index bed5700521..ff0a1ad979 100644 --- a/dimos/perception/detection2d/detic_2d_det.py +++ b/dimos/perception/detection2d/detic_2d_det.py @@ -27,7 +27,7 @@ def __init__(self, iou_threshold=0.3, max_age=5): self.iou_threshold = iou_threshold self.max_age = max_age self.next_id = 1 - self.tracks = {} # id -> {bbox, class_id, age, etc} + self.tracks = {} # id -> {bbox, class_id, age, mask, etc} def _calculate_iou(self, bbox1, bbox2): """Calculate IoU between two bboxes in format [x1,y1,x2,y2]""" @@ -46,14 +46,15 @@ def _calculate_iou(self, bbox1, bbox2): return intersection / union if union > 0 else 0 - def update(self, detections): + def update(self, detections, masks): """Update tracker with new detections Args: detections: List of [x1,y1,x2,y2,score,class_id] + masks: List of segmentation masks corresponding to detections Returns: - List of [track_id, bbox, score, class_id] + List of [track_id, bbox, score, class_id, mask] """ if len(detections) == 0: # Age existing tracks @@ -101,15 +102,17 @@ def update(self, detections): self.tracks[track_id]["bbox"] = detections[best_idx][:4] self.tracks[track_id]["score"] = detections[best_idx][4] self.tracks[track_id]["age"] = 0 + self.tracks[track_id]["mask"] = masks[best_idx] matched_indices.add(best_idx) - # Add to results + # Add to results with mask result.append( [ track_id, detections[best_idx][:4], detections[best_idx][4], int(detections[best_idx][5]), + 
self.tracks[track_id]["mask"], ] ) @@ -127,10 +130,11 @@ def update(self, detections): "score": det[4], "class_id": int(det[5]), "age": 0, + "mask": masks[i], } - # Add to results - result.append([new_id, det[:4], det[4], int(det[5])]) + # Add to results with mask directly from the track + result.append([new_id, det[:4], det[4], int(det[5]), masks[i]]) return result @@ -301,24 +305,26 @@ def process_image(self, image): image: Input image in BGR format (OpenCV) Returns: - tuple: (bboxes, track_ids, class_ids, confidences, names) + tuple: (bboxes, track_ids, class_ids, confidences, names, masks) - bboxes: list of [x1, y1, x2, y2] coordinates - track_ids: list of tracking IDs (or -1 if no tracking) - class_ids: list of class indices - confidences: list of detection confidences - names: list of class names + - masks: list of segmentation masks (numpy arrays) """ # Run inference with Detic outputs = self.predictor(image) instances = outputs["instances"].to("cpu") - # Extract bounding boxes, classes, and scores + # Extract bounding boxes, classes, scores, and masks if len(instances) == 0: - return [], [], [], [], [] + return [], [], [], [], [], [] boxes = instances.pred_boxes.tensor.numpy() class_ids = instances.pred_classes.numpy() scores = instances.scores.numpy() + masks = instances.pred_masks.numpy() # Convert boxes to [x1, y1, x2, y2] format bboxes = [] @@ -331,16 +337,18 @@ def process_image(self, image): # Apply tracking detections = [] + filtered_masks = [] for i, bbox in enumerate(bboxes): if scores[i] >= self.threshold: # Format for tracker: [x1, y1, x2, y2, score, class_id] detections.append(bbox + [scores[i], class_ids[i]]) + filtered_masks.append(masks[i]) if not detections: - return [], [], [], [], [] + return [], [], [], [], [], [] - # Update tracker with detections - track_results = self.tracker.update(detections) + # Update tracker with detections and correctly aligned masks + track_results = self.tracker.update(detections, filtered_masks) # Process 
tracking results track_ids = [] @@ -348,15 +356,24 @@ def process_image(self, image): tracked_class_ids = [] tracked_scores = [] tracked_names = [] + tracked_masks = [] - for track_id, bbox, score, class_id in track_results: + for track_id, bbox, score, class_id, mask in track_results: track_ids.append(int(track_id)) tracked_bboxes.append(bbox.tolist() if isinstance(bbox, np.ndarray) else bbox) tracked_class_ids.append(int(class_id)) tracked_scores.append(score) tracked_names.append(self.class_names[int(class_id)]) - - return tracked_bboxes, track_ids, tracked_class_ids, tracked_scores, tracked_names + tracked_masks.append(mask) + + return ( + tracked_bboxes, + track_ids, + tracked_class_ids, + tracked_scores, + tracked_names, + tracked_masks, + ) def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, names): """ diff --git a/dimos/perception/object_detection_stream.py b/dimos/perception/object_detection_stream.py index daa9fa8273..db4eabf66a 100644 --- a/dimos/perception/object_detection_stream.py +++ b/dimos/perception/object_detection_stream.py @@ -13,6 +13,12 @@ ) from dimos.types.vector import Vector from typing import Optional, Union +from dimos.types.manipulation import ObjectData + +from dimos.utils.logging_config import setup_logger + +# Initialize logger for the ObjectDetectionStream +logger = setup_logger("dimos.perception.object_detection_stream") class ObjectDetectionStream: @@ -22,6 +28,7 @@ class ObjectDetectionStream: 2. Estimates depth using Metric3D 3. Calculates 3D position and dimensions using camera intrinsics 4. Transforms coordinates to map frame + 5. Draws bounding boxes and segmentation masks on the frame Provides a stream of structured object data with position and rotation information. 
""" @@ -31,11 +38,13 @@ def __init__( camera_intrinsics=None, # [fx, fy, cx, cy] device="cuda", gt_depth_scale=1000.0, - min_confidence=0.5, + min_confidence=0.7, class_filter=None, # Optional list of class names to filter (e.g., ["person", "car"]) transform_to_map=None, # Optional function to transform coordinates to map frame detector: Optional[Union[Detic2DDetector, Yolo2DDetector]] = None, video_stream: Observable = None, + disable_depth: bool = False, # Flag to disable monocular Metric3D depth estimation + draw_masks: bool = False, # Flag to enable drawing segmentation masks ): """ Initialize the ObjectDetectionStream. @@ -49,26 +58,36 @@ def __init__( transform_to_map: Optional function to transform pose to map coordinates detector: Optional detector instance (Detic or Yolo) video_stream: Observable of video frames to process (if provided, returns a stream immediately) + disable_depth: Flag to disable monocular Metric3D depth estimation + draw_masks: Flag to enable drawing segmentation masks """ self.min_confidence = min_confidence self.class_filter = class_filter self.transform_to_map = transform_to_map + self.disable_depth = disable_depth + self.draw_masks = draw_masks # Initialize object detector self.detector = detector or Detic2DDetector(vocabulary=None, threshold=min_confidence) + # Set up camera intrinsics + self.camera_intrinsics = camera_intrinsics # Initialize depth estimation model - self.depth_model = Metric3D(gt_depth_scale) + self.depth_model = None + if not disable_depth: + self.depth_model = Metric3D(gt_depth_scale) - # Set up camera intrinsics - self.camera_intrinsics = camera_intrinsics - if camera_intrinsics is not None: - self.depth_model.update_intrinsic(camera_intrinsics) + if camera_intrinsics is not None: + self.depth_model.update_intrinsic(camera_intrinsics) - # Create 3x3 camera matrix for calculations - fx, fy, cx, cy = camera_intrinsics - self.camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + # 
Create 3x3 camera matrix for calculations + fx, fy, cx, cy = camera_intrinsics + self.camera_matrix = np.array( + [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32 + ) + else: + raise ValueError("camera_intrinsics must be provided") else: - raise ValueError("camera_intrinsics must be provided") + logger.info("Depth estimation disabled") # If video_stream is provided, create and store the stream immediately self.stream = None @@ -89,7 +108,9 @@ def create_stream(self, video_stream: Observable) -> Observable: def process_frame(frame): # Detect objects - bboxes, track_ids, class_ids, confidences, names = self.detector.process_image(frame) + bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image( + frame + ) # Create visualization viz_frame = frame.copy() @@ -107,36 +128,44 @@ def process_frame(frame): if self.class_filter and class_name not in self.class_filter: continue - # Get depth for this object - depth = calculate_depth_from_bbox(self.depth_model, frame, bbox) - if depth is None: - # Skip objects with invalid depth - continue - - # Calculate object position and rotation - position, rotation = calculate_position_rotation_from_bbox( - bbox, depth, self.camera_intrinsics - ) - - # Get object dimensions - width, height = calculate_object_size_from_bbox(bbox, depth, self.camera_intrinsics) - - # Transform to map frame if a transform function is provided - try: - if self.transform_to_map: - position = Vector([position["x"], position["y"], position["z"]]) - rotation = Vector([rotation["roll"], rotation["pitch"], rotation["yaw"]]) - position, rotation = self.transform_to_map( - position, rotation, source_frame="base_link" - ) - position = dict(x=position.x, y=position.y, z=position.z) - rotation = dict(roll=rotation.x, pitch=rotation.y, yaw=rotation.z) - except Exception as e: - print(f"Error transforming to map frame: {e}") - position, rotation = position, rotation - - # Create object data dictionary - object_data = { + if not 
self.disable_depth: + # Get depth for this object + depth = calculate_depth_from_bbox(self.depth_model, frame, bbox) + if depth is None: + # Skip objects with invalid depth + continue + # Calculate object position and rotation + position, rotation = calculate_position_rotation_from_bbox( + bbox, depth, self.camera_intrinsics + ) + # Get object dimensions + width, height = calculate_object_size_from_bbox( + bbox, depth, self.camera_intrinsics + ) + + # Transform to map frame if a transform function is provided + try: + if self.transform_to_map: + position = Vector([position["x"], position["y"], position["z"]]) + rotation = Vector( + [rotation["roll"], rotation["pitch"], rotation["yaw"]] + ) + position, rotation = self.transform_to_map( + position, rotation, source_frame="base_link" + ) + except Exception as e: + logger.error(f"Error transforming to map frame: {e}") + position, rotation = position, rotation + + else: + depth = -1 + position = Vector(0, 0, 0) + rotation = Vector(0, 0, 0) + width = -1 + height = -1 + + # Create a properly typed ObjectData instance + object_data: ObjectData = { "object_id": track_ids[i] if i < len(track_ids) else -1, "bbox": bbox, "depth": depth, @@ -146,6 +175,7 @@ def process_frame(frame): "position": position, "rotation": rotation, "size": {"width": width, "height": height}, + "segmentation_mask": masks[i], } objects.append(object_data) @@ -153,35 +183,65 @@ def process_frame(frame): # Add visualization x1, y1, x2, y2 = map(int, bbox) color = (0, 255, 0) # Green for detected objects + mask_color = (0, 200, 200) # Yellow-green for masks - # Draw bounding box - cv2.rectangle(viz_frame, (x1, y1), (x2, y2), color, 2) - - # Add text for class and position - text = f"{class_name}: {depth:.2f}m" - pos_text = f"Pos: ({position['x']:.2f}, {position['y']:.2f})" - - # Draw text background - text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0] - cv2.rectangle( - viz_frame, (x1, y1 - text_size[1] - 5), (x1 + text_size[0], y1), 
(0, 0, 0), -1 - ) + # Draw segmentation mask if available and valid + try: + if self.draw_masks and object_data["segmentation_mask"] is not None: + # Create a colored mask overlay + mask = object_data["segmentation_mask"].astype(np.uint8) + colored_mask = np.zeros_like(viz_frame) + colored_mask[mask > 0] = mask_color + + # Apply the mask with transparency + alpha = 0.5 # transparency factor + mask_area = mask > 0 + viz_frame[mask_area] = cv2.addWeighted( + viz_frame[mask_area], 1 - alpha, colored_mask[mask_area], alpha, 0 + ) - # Draw text - cv2.putText( - viz_frame, text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2 - ) + # Draw mask contour + contours, _ = cv2.findContours( + mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + cv2.drawContours(viz_frame, contours, -1, mask_color, 2) + except Exception as e: + logger.warning(f"Error drawing segmentation mask: {e}") - # Position text below - cv2.putText( - viz_frame, - pos_text, - (x1, y1 + 15), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), - 2, - ) + # Draw bounding box with metadata + try: + cv2.rectangle(viz_frame, (x1, y1), (x2, y2), color, 1) + + # Add text for class only (removed position data) + # Handle possible None values for class_name or track_ids[i] + class_text = class_name if class_name is not None else "Unknown" + id_text = ( + track_ids[i] if i < len(track_ids) and track_ids[i] is not None else "?" 
+ ) + text = f"{class_text}, ID: {id_text}" + + # Draw text background with smaller font + text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.3, 1)[0] + cv2.rectangle( + viz_frame, + (x1, y1 - text_size[1] - 5), + (x1 + text_size[0], y1), + (0, 0, 0), + -1, + ) + + # Draw text with smaller font + cv2.putText( + viz_frame, + text, + (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.3, + (255, 255, 255), + 1, + ) + except Exception as e: + logger.warning(f"Error drawing bounding box or text: {e}") return {"frame": frame, "viz_frame": viz_frame, "objects": objects} @@ -224,26 +284,30 @@ def format_detection_data(result): return "No objects detected." formatted_data = "[DETECTED OBJECTS]\n" - - for i, obj in enumerate(objects): - pos = obj["position"] - rot = obj["rotation"] - size = obj["size"] - bbox = obj["bbox"] - - # Format each object with a multiline f-string for better readability - bbox_str = f"[{int(bbox[0])}, {int(bbox[1])}, {int(bbox[2])}, {int(bbox[3])}]" - formatted_data += ( - f"Object {i + 1}: {obj['label']}\n" - f" ID: {obj['object_id']}\n" - f" Confidence: {obj['confidence']:.2f}\n" - f" Position: x={pos['x']:.2f}m, y={pos['y']:.2f}m, z={pos['z']:.2f}m\n" - f" Rotation: yaw={rot['yaw']:.2f} rad\n" - f" Size: width={size['width']:.2f}m, height={size['height']:.2f}m\n" - f" Depth: {obj['depth']:.2f}m\n" - f" Bounding box: {bbox_str}\n" - "----------------------------------\n" - ) + try: + for i, obj in enumerate(objects): + pos = obj["position"] + rot = obj["rotation"] + size = obj["size"] + bbox = obj["bbox"] + + # Format each object with a multiline f-string for better readability + bbox_str = f"[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]" + formatted_data += ( + f"Object {i + 1}: {obj['label']}\n" + f" ID: {obj['object_id']}\n" + f" Confidence: {obj['confidence']:.2f}\n" + f" Position: x={pos.x:.2f}m, y={pos.y:.2f}m, z={pos.z:.2f}m\n" + f" Rotation: yaw={rot.z:.2f} rad\n" + f" Size: width={size['width']:.2f}m, height={size['height']:.2f}m\n" 
+ f" Depth: {obj['depth']:.2f}m\n" + f" Bounding box: {bbox_str}\n" + "----------------------------------\n" + ) + except Exception as e: + logger.warning(f"Error formatting object {i}: {e}") + formatted_data += f"Object {i + 1}: [Error formatting data]" + formatted_data += "\n----------------------------------\n" return formatted_data diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py index aaf4b57083..90022023ca 100644 --- a/dimos/robot/robot.py +++ b/dimos/robot/robot.py @@ -21,10 +21,12 @@ from abc import ABC import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, List, Union, Dict, Any from dimos.hardware.interface import HardwareInterface from dimos.perception.spatial_perception import SpatialMemory +from dimos.manipulation.manipulation_interface import ManipulationInterface +from dimos.types.robot_capabilities import RobotCapability from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -38,6 +40,8 @@ from reactivex.scheduler import ThreadPoolScheduler from dimos.utils.threadpool import get_scheduler +from dimos.utils.reactive import backpressure +from dimos.stream.video_provider import VideoProvider logger = setup_logger("dimos.robot.robot") @@ -65,9 +69,10 @@ def __init__( output_dir: str = os.path.join(os.getcwd(), "assets", "output"), pool_scheduler: ThreadPoolScheduler = None, skill_library: SkillLibrary = None, - spatial_memory_dir: str = None, spatial_memory_collection: str = "spatial_memory", new_memory: bool = False, + capabilities: List[RobotCapability] = None, + video_stream: Optional[Observable] = None, ): """Initialize a Robot instance. @@ -77,7 +82,6 @@ def __init__( output_dir: Directory for storing output files. Defaults to "./assets/output". pool_scheduler: Thread pool scheduler. If None, one will be created. skill_library: Skill library instance. If None, one will be created. - spatial_memory_dir: Directory for storing spatial memory data. 
If None, uses output_dir/spatial_memory. spatial_memory_collection: Name of the collection in the ChromaDB database. new_memory: If True, creates a new spatial memory from scratch. Defaults to False. """ @@ -88,16 +92,19 @@ def __init__( self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() self.skill_library = skill_library if skill_library else SkillLibrary() - # Create output directory if it doesn't exist - os.makedirs(self.output_dir, exist_ok=True) + # Initialize robot capabilities + self.capabilities = capabilities or [] # Create output directory if it doesn't exist + os.makedirs(self.output_dir, exist_ok=True) logger.info(f"Robot outputs will be saved to: {self.output_dir}") + # Initialize memory properties + self.memory_dir = os.path.join(self.output_dir, "memory") + os.makedirs(self.memory_dir, exist_ok=True) + # Initialize spatial memory properties - self.spatial_memory_dir = spatial_memory_dir or os.path.join( - self.output_dir, "spatial_memory" - ) + self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory") self.spatial_memory_collection = spatial_memory_collection self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data") self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl") @@ -106,17 +113,14 @@ def __init__( os.makedirs(self.spatial_memory_dir, exist_ok=True) os.makedirs(self.db_path, exist_ok=True) - # Import SpatialMemory here to avoid circular imports - from dimos.perception.spatial_perception import SpatialMemory - # Initialize spatial memory - this will be handled by SpatialMemory class - video_stream = None + self._video_stream = video_stream transform_provider = None # Only create video stream if ROS control is available if self.ros_control is not None and self.ros_control.video_provider is not None: # Get video stream - video_stream = self.get_ros_video_stream(fps=10) # Lower FPS for processing + self._video_stream = self.get_ros_video_stream(fps=10) # Lower FPS 
for processing # Define transform provider def transform_provider(): @@ -125,6 +129,9 @@ def transform_provider(): return {"position": None, "rotation": None} return {"position": position, "rotation": rotation} + # Avoids circular imports + from dimos.perception.spatial_perception import SpatialMemory + # Create SpatialMemory instance - it will handle all initialization internally self._spatial_memory = SpatialMemory( collection_name=self.spatial_memory_collection, @@ -132,10 +139,25 @@ def transform_provider(): visual_memory_path=self.visual_memory_path, new_memory=new_memory, output_dir=self.spatial_memory_dir, - video_stream=video_stream, + video_stream=self._video_stream, transform_provider=transform_provider, ) + # Initialize manipulation interface if the robot has manipulation capability + self._manipulation_interface = None + if RobotCapability.MANIPULATION in self.capabilities: + # Initialize manipulation memory properties if the robot has manipulation capability + self.manipulation_memory_dir = os.path.join(self.memory_dir, "manipulation_memory") + + # Create manipulation memory directory + os.makedirs(self.manipulation_memory_dir, exist_ok=True) + + self._manipulation_interface = ManipulationInterface( + output_dir=self.output_dir, # Use the main output directory + new_memory=new_memory, + ) + logger.info("Manipulation interface initialized") + def get_ros_video_stream(self, fps: int = 30) -> Observable: """Get the ROS video stream with rate limiting and frame processing. @@ -323,14 +345,53 @@ def set_hardware_configuration(self, configuration): """ self.hardware_interface.set_configuration(configuration) + @property + def spatial_memory(self) -> SpatialMemory: + """Get the robot's spatial memory. + + Returns: + SpatialMemory: The robot's spatial memory system. + """ + return self._spatial_memory + + @property + def manipulation_interface(self) -> Optional[ManipulationInterface]: + """Get the robot's manipulation interface. 
+ + Returns: + ManipulationInterface: The robot's manipulation interface or None if not available. + """ + return self._manipulation_interface + + def has_capability(self, capability: RobotCapability) -> bool: + """Check if the robot has a specific capability. + + Args: + capability: The capability to check for + + Returns: + bool: True if the robot has the capability, False otherwise + """ + return capability in self.capabilities + def get_spatial_memory(self) -> Optional[SpatialMemory]: """Simple getter for the spatial memory instance. + (For backwards compatibility) Returns: The spatial memory instance or None if not set. """ return self._spatial_memory if self._spatial_memory else None + @property + def video_stream(self) -> Optional[Observable]: + """Get the robot's video stream. + + Returns: + Observable: The robot's video stream or None if not available. + """ + return self._video_stream + def cleanup(self): """Clean up resources used by the robot. @@ -356,3 +417,20 @@ def __init__(self): def my_print(self): print("Hello, world!") + + +class MockManipulationRobot(Robot): + def __init__(self, skill_library: Optional[SkillLibrary] = None): + video_provider = VideoProvider("webcam", video_source=0) # Default camera + video_stream = backpressure( + video_provider.capture_video_as_observable(realtime=True, fps=30) + ) + + super().__init__( + capabilities=[RobotCapability.MANIPULATION], + video_stream=video_stream, + skill_library=skill_library, + ) + self.camera_intrinsics = [489.33, 367.0, 320.0, 240.0] + self.ros_control = None + self.hardware_interface = None diff --git a/dimos/robot/unitree/unitree_go2.py b/dimos/robot/unitree/unitree_go2.py index 7f1d760b34..8667edfcf8 100644 --- a/dimos/robot/unitree/unitree_go2.py +++ b/dimos/robot/unitree/unitree_go2.py @@ -53,8 +53,6 @@ def __init__( disable_video_stream: bool = False, mock_connection: bool = False, skills: Optional[Union[MyUnitreeSkills, AbstractSkill]] = None, - spatial_memory_dir: str = None, - 
spatial_memory_collection: str = "spatial_memory", new_memory: bool = False, ): """Initialize the UnitreeGo2 robot. @@ -94,8 +92,6 @@ def __init__( ros_control=ros_control, output_dir=output_dir, skill_library=skills, - spatial_memory_dir=spatial_memory_dir, - spatial_memory_collection=spatial_memory_collection, new_memory=new_memory, ) diff --git a/dimos/skills/manipulation/abstract_manipulation_skill.py b/dimos/skills/manipulation/abstract_manipulation_skill.py new file mode 100644 index 0000000000..8881548540 --- /dev/null +++ b/dimos/skills/manipulation/abstract_manipulation_skill.py @@ -0,0 +1,60 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abstract base class for manipulation skills.""" + +from typing import Optional + +from dimos.skills.skills import AbstractRobotSkill, Colors +from dimos.robot.robot import Robot +from dimos.manipulation.manipulation_interface import ManipulationInterface +from dimos.types.robot_capabilities import RobotCapability + + +class AbstractManipulationSkill(AbstractRobotSkill): + """Base class for all manipulation-related skills. + + This abstract class provides access to the robot's manipulation memory system. + """ + + def __init__(self, *args, robot: Optional[Robot] = None, **kwargs): + """Initialize the manipulation skill. 
+ + Args: + robot: The robot instance to associate with this skill + """ + super().__init__(*args, robot=robot, **kwargs) + + if self._robot and not self._robot.manipulation_interface: + raise NotImplementedError( + "This robot does not have a manipulation interface implemented" + ) + + @property + def manipulation_interface(self) -> Optional[ManipulationInterface]: + """Get the robot's manipulation interface. + + Returns: + ManipulationInterface: The robot's manipulation interface or None if not available + + Raises: + RuntimeError: If the robot doesn't have the MANIPULATION capability + """ + if self._robot is None: + return None + + if not self._robot.has_capability(RobotCapability.MANIPULATION): + raise RuntimeError("This robot does not have manipulation capabilities") + + return self._robot.manipulation_interface diff --git a/dimos/skills/manipulation/force_constraint_skill.py b/dimos/skills/manipulation/force_constraint_skill.py new file mode 100644 index 0000000000..d7a97287b2 --- /dev/null +++ b/dimos/skills/manipulation/force_constraint_skill.py @@ -0,0 +1,73 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, List, Tuple +from pydantic import Field + +from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill +from dimos.skills.skills import AbstractRobotSkill +from dimos.types.manipulation import ForceConstraint, Vector +from dimos.utils.logging_config import setup_logger + +# Initialize logger +logger = setup_logger("dimos.skills.force_constraint_skill") + + +class ForceConstraintSkill(AbstractManipulationSkill): + """ + Skill for generating force constraints for robot manipulation. + + This skill generates force constraints and adds them to the ManipulationInterface's + agent_constraints list for tracking constraints created by the Agent. + """ + + # Constraint parameters + min_force: float = Field(0.0, description="Minimum force magnitude in Newtons") + max_force: float = Field(100.0, description="Maximum force magnitude in Newtons to apply") + + # Force direction as (x,y) tuple + force_direction: Optional[Tuple[float, float]] = Field( + None, description="Force direction vector (x,y)" + ) + + # Description + description: str = Field("", description="Description of the force constraint") + + def __call__(self) -> ForceConstraint: + """ + Generate a force constraint based on the parameters. 
+ + Returns: + ForceConstraint: The generated constraint + """ + # Create force direction vector if provided (convert 2D point to 3D vector with z=0) + force_direction_vector = None + if self.force_direction: + force_direction_vector = Vector(self.force_direction[0], self.force_direction[1], 0.0) + + # Create and return the constraint + constraint = ForceConstraint( + max_force=self.max_force, + min_force=self.min_force, + force_direction=force_direction_vector, + description=self.description, + ) + + # Add constraint to manipulation interface for Agent recall + self.manipulation_interface.add_constraint(constraint) + + # Log the constraint creation + logger.info(f"Generated force constraint: {self.description}") + + return constraint diff --git a/dimos/skills/manipulation/manipulate_skill.py b/dimos/skills/manipulation/manipulate_skill.py new file mode 100644 index 0000000000..efd923f8c6 --- /dev/null +++ b/dimos/skills/manipulation/manipulate_skill.py @@ -0,0 +1,176 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Manipulate(AbstractManipulationSkill):
    """
    Skill for executing manipulation tasks with constraints.
    Can be called by an LLM with a list of manipulation constraints.
    """

    description: str = Field("", description="Description of the manipulation task")

    # Target object information
    target_object: str = Field(
        "", description="Semantic label of the target object (e.g., 'cup', 'box')"
    )

    target_point: str = Field(
        "", description="(X,Y) point in pixel-space of the point to manipulate on target object"
    )

    # Constraints - can be set directly
    constraints: List[str] = Field(
        [],
        description="List of AbstractConstraint constraint IDs from AgentMemory to apply to the manipulation task",
    )

    # Object movement tolerances
    object_tolerances: Dict[str, float] = Field(
        {},  # Empty dict as default
        description="Dictionary mapping object IDs to movement tolerances (0.0 = immovable, 1.0 = freely movable)",
    )

    def __call__(self) -> Dict[str, Any]:
        """
        Execute a manipulation task with the given constraints.

        Returns:
            Dict[str, Any]: Result of the manipulation operation
        """
        # Resolve constraint IDs into a single task-level constraint set
        constraint = self._build_manipulation_constraint()

        # Create task with a short unique ID
        task_id = f"{str(uuid.uuid4())[:4]}"
        timestamp = time.time()

        # Snapshot the environment state (detected objects + tolerances)
        metadata = self._build_manipulation_metadata()

        task = ManipulationTask(
            description=self.description,
            target_object=self.target_object,
            target_point=self._parse_target_point(),
            constraints=constraint,
            metadata=metadata,
            timestamp=timestamp,
            task_id=task_id,
            result=None,
        )

        # Record the task so the Agent can recall it later
        self.manipulation_interface.add_manipulation_task(task)

        # Execute the manipulation
        result = self._execute_manipulation(task)

        # Log the execution
        logger.info(
            f"Executed manipulation '{self.description}' with constraints: {self.constraints}"
        )

        return result

    def _parse_target_point(self) -> Optional[tuple]:
        """
        Parse the "(x,y)" pixel-space string into an (x, y) integer tuple.

        Returns:
            The parsed (x, y) tuple, or None when no target point was provided.
            (Previously the default empty string crashed with ValueError inside
            int(), and malformed input could silently yield non-pair tuples.)

        Raises:
            ValueError: If a non-empty string is not a valid "(x,y)" pair.
        """
        text = self.target_point.strip()
        if not text:
            return None
        parts = text.strip("()").split(",")
        if len(parts) != 2:
            raise ValueError(f"Invalid target_point format: {self.target_point!r}")
        return (int(parts[0]), int(parts[1]))

    def _build_manipulation_metadata(self) -> ManipulationMetadata:
        """
        Build metadata for the current environment state, including object data and movement tolerances.

        Returns:
            ManipulationMetadata: Timestamped snapshot of detected objects, with
            explicitly supplied tolerances applied and all remaining objects
            defaulting to 0.0 (immovable).
        """
        # Get detected objects from the manipulation interface (best-effort)
        detected_objects = []
        try:
            detected_objects = self.manipulation_interface.get_latest_objects() or []
        except Exception as e:
            logger.warning(f"Failed to get detected objects: {e}")

        # Key objects by ID for O(1) lookup; copy to avoid mutating the originals
        objects_by_id = {}
        for obj in detected_objects:
            obj_id = str(obj.get("object_id", -1))
            objects_by_id[obj_id] = dict(obj)

        objects_data: Dict[str, Any] = {}

        # First, apply all explicitly specified tolerances
        for object_id, tolerance in self.object_tolerances.items():
            if object_id in objects_by_id:
                obj_data = objects_by_id[object_id]
                obj_data["movement_tolerance"] = tolerance
                objects_data[object_id] = obj_data

        # Any detected objects not explicitly given tolerances default to immovable
        for obj_id, obj in objects_by_id.items():
            if obj_id not in self.object_tolerances:
                obj["movement_tolerance"] = 0.0
                objects_data[obj_id] = obj

        metadata: ManipulationMetadata = {"timestamp": time.time(), "objects": objects_data}
        return metadata

    def _build_manipulation_constraint(self) -> ManipulationTaskConstraint:
        """
        Build a ManipulationTaskConstraint object from the provided parameters.

        Entries may be AbstractConstraint instances (added directly) or string
        IDs resolved through the manipulation interface; IDs that do not
        resolve are silently skipped.
        """
        constraint = ManipulationTaskConstraint()

        for c in self.constraints:
            if isinstance(c, AbstractConstraint):
                constraint.add_constraint(c)
            elif isinstance(c, str) and self.manipulation_interface:
                # Try to load constraint from ID
                saved_constraint = self.manipulation_interface.get_constraint(c)
                if saved_constraint:
                    constraint.add_constraint(saved_constraint)

        return constraint

    # TODO: Implement
    def _execute_manipulation(self, task: ManipulationTask) -> Dict[str, Any]:
        """
        Execute the manipulation with the given constraint.

        Args:
            task: The manipulation task to execute

        Returns:
            Dict[str, Any]: Result of the manipulation operation
        """
        return {"success": True}
class RotationConstraintSkill(AbstractManipulationSkill):
    """
    Skill for generating rotation constraints for robot manipulation.

    This skill generates rotation constraints and adds them to the ManipulationInterface's
    agent_constraints list for tracking constraints created by the Agent.
    """

    # Rotation axis parameter
    rotation_axis: Literal["roll", "pitch", "yaw"] = Field(
        "roll",
        description="Axis to rotate around: 'roll' (x-axis), 'pitch' (y-axis), or 'yaw' (z-axis)",
    )

    # Simple angle values for rotation (in degrees)
    start_angle: Optional[float] = Field(None, description="Starting angle in degrees")
    end_angle: Optional[float] = Field(None, description="Ending angle in degrees")

    # Pivot points as (x,y) tuples
    pivot_point: Optional[Tuple[float, float]] = Field(
        None, description="Pivot point (x,y) for rotation"
    )

    # TODO: Secondary pivot point for more complex rotations
    secondary_pivot_point: Optional[Tuple[float, float]] = Field(
        None, description="Secondary pivot point (x,y) for double-pivot rotation"
    )

    def _angle_vector(self, angle_degrees: Optional[float]) -> Optional[Vector]:
        """Place a scalar angle (degrees) on the configured rotation axis.

        Returns a 3D vector whose roll/pitch/yaw component matches
        ``self.rotation_axis``, or None when no angle was provided.
        (Previously this construction was duplicated verbatim for start
        and end angles.)
        """
        if angle_degrees is None:
            return None
        components = [0.0, 0.0, 0.0]
        components[{"roll": 0, "pitch": 1, "yaw": 2}[self.rotation_axis]] = angle_degrees
        return Vector(*components)

    @staticmethod
    def _point_vector(point: Optional[Tuple[float, float]]) -> Optional[Vector]:
        """Lift an optional 2D (x, y) point into a 3D vector with z=0."""
        if point is None:
            return None
        return Vector(point[0], point[1], 0.0)

    def __call__(self) -> RotationConstraint:
        """
        Generate a rotation constraint based on the parameters.

        This implementation supports rotation around a single axis (roll, pitch, or yaw).

        Returns:
            RotationConstraint: The generated constraint
        """
        # rotation_axis is guaranteed to be one of "roll", "pitch", or "yaw"
        # due to the Literal type constraint.
        constraint = RotationConstraint(
            rotation_axis=self.rotation_axis,
            start_angle=self._angle_vector(self.start_angle),
            end_angle=self._angle_vector(self.end_angle),
            pivot_point=self._point_vector(self.pivot_point),
            secondary_pivot_point=self._point_vector(self.secondary_pivot_point),
        )

        # Register with the manipulation interface so the Agent can recall it
        self.manipulation_interface.add_constraint(constraint)

        # Log the constraint creation
        logger.info(f"Generated rotation constraint around {self.rotation_axis} axis")

        return constraint
class TranslationConstraintSkill(AbstractManipulationSkill):
    """
    Skill for generating translation constraints for robot manipulation.

    This skill generates translation constraints and adds them to the ManipulationInterface's
    agent_constraints list for tracking constraints created by the Agent.
    """

    # Constraint parameters
    translation_axis: Literal["x", "y", "z"] = Field(
        "x", description="Axis to translate along: 'x', 'y', or 'z'"
    )

    reference_point: Optional[Tuple[float, float]] = Field(
        None, description="Reference point (x,y) on the target object for translation constraining"
    )

    bounds_min: Optional[Tuple[float, float]] = Field(
        None, description="Minimum bounds (x,y) for bounded translation"
    )

    bounds_max: Optional[Tuple[float, float]] = Field(
        None, description="Maximum bounds (x,y) for bounded translation"
    )

    target_point: Optional[Tuple[float, float]] = Field(
        None, description="Final target position (x,y) for translation constraining"
    )

    # Description
    description: str = Field("", description="Description of the translation constraint")

    @staticmethod
    def _to_vector(point: Optional[Tuple[float, float]]) -> Optional[Vector]:
        """Lift an optional 2D (x, y) point into a 3D vector with z=0."""
        if point is None:
            return None
        return Vector(point[0], point[1], 0.0)

    def __call__(self) -> TranslationConstraint:
        """
        Generate a translation constraint based on the parameters.

        Returns:
            TranslationConstraint: The generated constraint
        """
        constraint = TranslationConstraint(
            translation_axis=self.translation_axis,
            reference_point=self._to_vector(self.reference_point),
            bounds_min=self._to_vector(self.bounds_min),
            bounds_max=self._to_vector(self.bounds_max),
            target_point=self._to_vector(self.target_point),
            # Pass the description through, matching ForceConstraintSkill
            description=self.description,
        )

        # Add constraint to manipulation interface
        self.manipulation_interface.add_constraint(constraint)

        # Log the constraint creation
        logger.info(f"Generated translation constraint along {self.translation_axis} axis")

        # Bug fix: return the constraint itself (the declared return type and
        # the sibling skills all return the constraint); previously this
        # returned a bare {"success": True} dict.
        return constraint
@dataclass
class RotationConstraint(AbstractConstraint):
    """Constraint parameters for rotational movement around a single axis."""

    # Axis to rotate around; None means the axis has not been specified yet
    rotation_axis: Optional[Literal["roll", "pitch", "yaw"]] = None
    start_angle: Optional[Vector] = None  # Angle values applied to the specified rotation axis
    end_angle: Optional[Vector] = None  # Angle values applied to the specified rotation axis
    pivot_point: Optional[Vector] = None  # Point of rotation
    secondary_pivot_point: Optional[Vector] = None  # For double point rotations


@dataclass
class ForceConstraint(AbstractConstraint):
    """Constraint parameters for force application."""

    max_force: float = 0.0  # Maximum force in newtons
    min_force: float = 0.0  # Minimum force in newtons
    force_direction: Optional[Vector] = None  # Direction of force application


class ObjectData(TypedDict, total=False):
    """Data about an object in the manipulation scene.

    All keys are optional (total=False); producers fill in whatever the
    perception pipeline provides.
    """

    object_id: int  # Unique ID for the object
    bbox: List[float]  # Bounding box [x1, y1, x2, y2]
    depth: float  # Depth in meters from Metric3d
    confidence: float  # Detection confidence
    class_id: int  # Class ID from the detector
    label: str  # Semantic label (e.g., 'cup', 'table')
    movement_tolerance: float  # (0.0 = immovable, 1.0 = freely movable)
    segmentation_mask: np.ndarray  # Binary mask of the object's pixels
    position: Dict[str, float]  # 3D position {x, y, z}
    rotation: Dict[str, float]  # 3D rotation {roll, pitch, yaw}
    size: Dict[str, float]  # Object dimensions {width, height}


class ManipulationMetadata(TypedDict, total=False):
    """Typed metadata for manipulation constraints."""

    timestamp: float  # Unix time the snapshot was taken
    objects: Dict[str, ObjectData]  # Detected objects keyed by object ID string


@dataclass
class ManipulationTaskConstraint:
    """Set of constraints for a specific manipulation action."""

    constraints: List[AbstractConstraint] = field(default_factory=list)

    def add_constraint(self, constraint: AbstractConstraint) -> None:
        """Add a constraint to this set."""
        # Dataclass equality makes this a value-based duplicate check
        if constraint not in self.constraints:
            self.constraints.append(constraint)

    def get_constraints(self) -> List[AbstractConstraint]:
        """Get all constraints in this set."""
        return self.constraints


@dataclass
class ManipulationTask:
    """Complete definition of a manipulation task.

    The ``constraints`` field is deliberately permissive: it may hold a list
    of constraints, a single constraint, or a ManipulationTaskConstraint set;
    the accessor methods below normalize over all three shapes.
    """

    description: str
    target_object: str  # Semantic label of target object
    target_point: Optional[Tuple[float, float]] = (
        None  # (X,Y) point in pixel-space of the point to manipulate on target object
    )
    metadata: ManipulationMetadata = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)
    task_id: str = ""
    result: Optional[Dict[str, Any]] = None  # Any result data from the task execution
    constraints: Union[List[AbstractConstraint], ManipulationTaskConstraint, AbstractConstraint] = (
        field(default_factory=list)
    )

    def add_constraint(self, constraint: AbstractConstraint) -> None:
        """Add a constraint to this manipulation task."""
        # If constraints is a ManipulationTaskConstraint object, delegate
        if isinstance(self.constraints, ManipulationTaskConstraint):
            self.constraints.add_constraint(constraint)
            return

        # If constraints is a single AbstractConstraint, convert to list
        if isinstance(self.constraints, AbstractConstraint):
            self.constraints = [self.constraints, constraint]
            return

        # If constraints is a list, append to it
        # This will also handle empty lists (the default case)
        self.constraints.append(constraint)

    def get_constraints(self) -> List[AbstractConstraint]:
        """Get all constraints in this manipulation task."""
        # If constraints is a ManipulationTaskConstraint object
        if isinstance(self.constraints, ManipulationTaskConstraint):
            return self.constraints.get_constraints()

        # If constraints is a single AbstractConstraint, return as list
        if isinstance(self.constraints, AbstractConstraint):
            return [self.constraints]

        # If constraints is a list (including empty list), return it
        return self.constraints
+ +from dimos.skills.skills import SkillLibrary +import tests.test_header +import os + +import time +from dotenv import load_dotenv +from dimos.agents.claude_agent import ClaudeAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal +from dimos.skills.visual_navigation_skills import FollowHuman +import reactivex as rx +import reactivex.operators as ops +from dimos.stream.audio.pipelines import tts, stt +import threading +import json +import cv2 +import numpy as np +import os +import datetime +from dimos.types.vector import Vector +from dimos.skills.speak import Speak +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.agents.agent import OpenAIAgent +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from openai import OpenAI +from dimos.utils.reactive import backpressure +from dimos.stream.video_provider import VideoProvider +from reactivex.subject import Subject, BehaviorSubject +from dimos.utils.logging_config import setup_logger +from dimos.skills.manipulation.translation_constraint_skill import TranslationConstraintSkill +from dimos.skills.manipulation.rotation_constraint_skill import RotationConstraintSkill +from dimos.skills.manipulation.manipulate_skill import Manipulate +from dimos.robot.robot import MockManipulationRobot + +# Initialize logger for the agent module +logger = setup_logger("dimos.tests.test_manipulation_agent") + +# Load API key from environment +load_dotenv() + +# Allow command line arguments to control spatial memory parameters 
def parse_arguments():
    """Parse command-line options controlling spatial memory behavior."""
    cli = argparse.ArgumentParser(
        description="Run the robot with optional spatial memory parameters"
    )
    cli.add_argument(
        "--new-memory",
        action="store_true",
        help="Create a new spatial memory from scratch",
    )
    return cli.parse_args()
def draw_manipulation_point(frame):
    """Overlay the most recent manipulation point (if any) onto a copy of the frame.

    Returns the original frame untouched when there is no frame, no point, or
    drawing fails for any reason.
    """
    try:
        # Nothing to draw without both a frame and a published point
        if frame is None or latest_manipulation_point.value is None:
            return frame

        # Work on a copy so the source frame is never mutated
        annotated = frame.copy()

        px, py = latest_manipulation_point.value
        draw_point_on_frame(annotated, px, py)

        return annotated
    except Exception as e:
        logger.error(f"Error drawing manipulation point: {e}")
        return frame
def process_manipulation_point(response, frame):
    """Parse an "x,y" VLM response, publish the point, and save a reference image."""
    logger.info(f"Processing manipulation point with response: {response}")
    try:
        # Expect exactly two comma-separated coordinates
        pieces = response.strip().split(",")
        if len(pieces) != 2:
            logger.error(f"Invalid coordinate format: {response}")
            return

        px, py = int(pieces[0]), int(pieces[1])

        # Publish the freshly parsed point for downstream visualization
        latest_manipulation_point.on_next((px, py))

        # Keep a static snapshot of the frame with the point for reference
        save_manipulation_point_image(frame, px, py)

    except Exception as e:
        logger.error(f"Error processing manipulation point: {e}")
of the frame for static image saving + visualization = frame.copy() + + # Draw the manipulation point + draw_point_on_frame(visualization, x, y) + + # Create directory if it doesn't exist + output_dir = os.path.join(os.getcwd(), "assets", "agent", "manipulation_agent") + os.makedirs(output_dir, exist_ok=True) + + # Save image with timestamp + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = os.path.join(output_dir, f"manipulation_point_{timestamp}.jpg") + cv2.imwrite(output_path, visualization) + + logger.info(f"Saved manipulation point visualization to {output_path}") + except Exception as e: + logger.error(f"Error saving manipulation point image: {e}") + + +# Subscribe to video stream to capture current frame +# Use `current_frame_subject` BehaviorSubject to store the latest frame for manipulation point visualization +robot.video_stream.subscribe( + on_next=lambda frame: current_frame_subject.on_next( + frame.copy() if frame is not None else None + ), + on_error=lambda error: logger.error(f"Error in video stream: {error}"), +) + +# Create Qwen client +qwen_client = OpenAI( + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key=os.getenv("ALIBABA_API_KEY"), +) + +# Create temporary agent for processing +manipulation_vlm = ClaudeAgent( + dev_name="QwenSingleFrameAgent", + # openai_client=qwen_client, + # model_name="qwen2.5-vl-72b-instruct", + # tokenizer=HuggingFaceTokenizer(model_name=f"Qwen/qwen2.5-vl-72b-instruct"), + # max_output_tokens_per_request=100, + system_query="You are a robot that is trying to perform a manipulation task. 
", + # input_video_stream=robot.video_stream, + skills=manipulation_skills, + input_query_stream=web_interface.query_stream, + # input_data_stream=enhanced_data_stream, +) + +# # Subscribe to VLM responses to process manipulation points +# manipulation_vlm.get_response_observable().subscribe( +# on_next=lambda response: process_manipulation_point(response, current_frame_subject.value), +# on_error=lambda error: logger.error(f"Error in VLM response stream: {error}"), +# ) + + +# Create a ClaudeAgent instance +# manipulation_agent = ClaudeAgent( +# dev_name="test_agent", +# # input_query_stream=stt_node.emit_text(), +# input_query_stream=manipulation_vlm.get_response_observable(), +# input_data_stream=enhanced_data_stream, # Add the enhanced data stream +# skills=robot.get_skills(), +# system_query="system_query", +# model_name="claude-3-7-sonnet-latest", +# thinking_budget_tokens=0 +# ) + +# tts_node = tts() +# tts_node.consume_text(agent.get_response_observable()) + +# robot_skills = robot.get_skills() +# robot_skills.add(ObserveStream) +# robot_skills.add(KillSkill) +# robot_skills.add(NavigateWithText) +# robot_skills.add(FollowHuman) +# robot_skills.add(GetPose) +# # robot_skills.add(Speak) +# robot_skills.add(NavigateToGoal) +# robot_skills.create_instance("ObserveStream", robot=robot, agent=manipulation_agent) +# robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +# robot_skills.create_instance("NavigateWithText", robot=robot) +# robot_skills.create_instance("FollowHuman", robot=robot) +# robot_skills.create_instance("GetPose", robot=robot) +# robot_skills.create_instance("NavigateToGoal", robot=robot) +# robot_skills.create_instance("Speak", tts_node=tts_node) + +# Subscribe to agent responses and send them to the subject +# manipulation_agent.get_response_observable().subscribe( +# lambda x: agent_response_subject.on_next(x) +# ) + + +web_interface.run() diff --git a/tests/test_object_detection_stream.py 
b/tests/test_object_detection_stream.py index 261f59382a..63777123ea 100644 --- a/tests/test_object_detection_stream.py +++ b/tests/test_object_detection_stream.py @@ -72,8 +72,8 @@ def print_results(self, objects: List[Dict[str, Any]]): print( f"{i + 1}. {obj['label']} (ID: {obj['object_id']}, Conf: {obj['confidence']:.2f})" ) - print(f" Position: x={pos['x']:.2f}, y={pos['y']:.2f}, z={pos['z']:.2f} m") - print(f" Rotation: yaw={rot['yaw']:.2f} rad") + print(f" Position: x={pos.x:.2f}, y={pos.y:.2f}, z={pos.z:.2f} m") + print(f" Rotation: yaw={rot.z:.2f} rad") print(f" Size: width={size['width']:.2f}, height={size['height']:.2f} m") print(f" Depth: {obj['depth']:.2f} m") print("-" * 30) @@ -121,6 +121,7 @@ def main(): transform_to_map=robot.ros_control.transform_pose, detector=detector, video_stream=video_stream, + disable_depth=True, ) else: # webcam mode @@ -155,6 +156,7 @@ def main(): class_filter=class_filter, detector=detector, video_stream=video_stream, + disable_depth=True, ) # Set placeholder robot for cleanup diff --git a/tests/test_unitree_ros_v0.0.4.py b/tests/test_unitree_ros_v0.0.4.py index 79f47dfef0..e4086074cc 100644 --- a/tests/test_unitree_ros_v0.0.4.py +++ b/tests/test_unitree_ros_v0.0.4.py @@ -48,9 +48,6 @@ def parse_arguments(): parser = argparse.ArgumentParser( description="Run the robot with optional spatial memory parameters" ) - parser.add_argument( - "--spatial-memory-dir", type=str, help="Directory for storing spatial memory data" - ) parser.add_argument( "--voice", action="store_true", @@ -66,7 +63,6 @@ def parse_arguments(): ip=os.getenv("ROBOT_IP"), skills=MyUnitreeSkills(), mock_connection=False, - spatial_memory_dir=args.spatial_memory_dir, # Will use default if None new_memory=True, )