diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index dfa6228492..3196c823e5 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -10,8 +10,15 @@
}
},
"containerEnv": {
- "PYTHONPATH": "${localEnv:PYTHONPATH}:/workspaces/dimos"
+ "PYTHONPATH": "${localEnv:PYTHONPATH}:/workspaces/dimos",
+ "DISPLAY": "${localEnv:DISPLAY}",
+ "WAYLAND_DISPLAY": "${localEnv:WAYLAND_DISPLAY}",
+ "XDG_RUNTIME_DIR": "${localEnv:XDG_RUNTIME_DIR}"
},
+ "mounts": [
+ "source=/tmp/.X11-unix,target=/tmp/.X11-unix,type=bind",
+ "source=${localEnv:XDG_RUNTIME_DIR},target=${localEnv:XDG_RUNTIME_DIR},type=bind"
+ ],
"postCreateCommand": "git config --global --add safe.directory /workspaces/dimos && cd /workspaces/dimos && pre-commit install",
"settings": {
"notebook.formatOnSave.enabled": true,
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index f9afc96de5..5dc19917e5 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -205,6 +205,21 @@ jobs:
cmd: "pytest -m lcm"
dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }}
+ run-integration-tests:
+ needs: [check-changes, dev]
+ if: always()
+ uses: ./.github/workflows/tests.yml
+ secrets: inherit
+ with:
+ should-run: ${{
+ needs.check-changes.result == 'success' &&
+ ((needs.dev.result == 'success') ||
+ (needs.dev.result == 'skipped' &&
+ needs.check-changes.outputs.tests == 'true'))
+ }}
+ cmd: "pytest -m integration"
+ dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }}
+
run-mypy:
needs: [check-changes, ros-dev]
if: always()
diff --git a/.gitignore b/.gitignore
index 0e6781f3a8..e52d08ba32 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,3 +63,4 @@ yolo11n.pt
/.mypy_cache*
*mobileclip*
+/results
diff --git a/README.md b/README.md
index 9a74d63aa7..9ccc7de197 100644
--- a/README.md
+++ b/README.md
@@ -1,502 +1,257 @@
-
-
-
-
-
-
- A simple two-shot PlanningAgent
-
-
-
- 3rd person POV
-
-
-
-
-
-# The Dimensional Framework
-*The universal framework for AI-native generalist robotics*
-
-## What is Dimensional?
-
-Dimensional is an open-source framework for building agentive generalist robots. DimOS allows off-the-shelf Agents to call tools/functions and read sensor/state data directly from ROS.
-
-The framework enables neurosymbolic orchestration of Agents as generalized spatial reasoners/planners and Robot state/action primitives as functions.
+
+ Program Atoms
+ The Agentive Operating System for Generalist Robotics
+
+
+
+
+[](https://discord.gg/8m6HMArf)
+[](https://github.com/dimensionalOS/dimos/stargazers)
+[](https://github.com/dimensionalOS/dimos/fork)
+[](https://github.com/dimensionalOS/dimos/graphs/contributors)
+
+
+
+
+[](https://www.docker.com/)
+
+
+ Features •
+ Installation •
+ Documentation •
+ Development •
+ Contributing
+
-The result: cross-embodied *"Dimensional Applications"* exceptional at generalization and robust at symbolic action execution.
-
-## DIMOS x Unitree Go2 (OUT OF DATE)
-
-We are shipping a first look at the DIMOS x Unitree Go2 integration, allowing for off-the-shelf Agents() to "call" Unitree ROS2 Nodes and WebRTC action primitives, including:
-
-- Navigation control primitives (move, reverse, spinLeft, spinRight, etc.)
-- WebRTC control primitives (FrontPounce, FrontFlip, FrontJump, etc.)
-- Camera feeds (image_raw, compressed_image, etc.)
-- IMU data
-- State information
-- Lidar / PointCloud primitives
-- Any other generic Unitree ROS2 topics
-
-### Features
-
-- **DimOS Agents**
- - Agent() classes with planning, spatial reasoning, and Robot.Skill() function calling abilities.
- - Integrate with any off-the-shelf hosted or local model: OpenAIAgent, ClaudeAgent, GeminiAgent 🚧, DeepSeekAgent 🚧, HuggingFaceRemoteAgent, HuggingFaceLocalAgent, etc.
- - Modular agent architecture for easy extensibility and chaining of Agent output --> Subagents input.
- - Agent spatial / language memory for location grounded reasoning and recall.
+
-- **DimOS Infrastructure**
- - A reactive data streaming architecture using RxPY to manage real-time video (or other sensor input), outbound commands, and inbound robot state between the DimOS interface, Agents, and ROS2.
- - Robot Command Queue to handle complex multi-step actions to Robot.
- - Simulation bindings (Genesis, Isaacsim, etc.) to test your agentive application before deploying to a physical robot.
+> \[!NOTE]
+>
+> ⚠️ **Alpha Pre-Release: Expect Breaking Changes** ⚠️
-- **DimOS Interface / Development Tools**
- - Local development interface to control your robot, orchestrate agents, visualize camera/lidar streams, and debug your dimensional agentive application.
+# The Dimensional Framework
-## MacOS Installation
+Dimensional is the open-source, universal operating system for generalist robotics. On DimOS, developers
+can design, build, and run physical ("dimensional") applications that run on any humanoid, quadruped,
+drone, or wheeled embodiment.
+
+**Programming physical robots is now as simple as programming digital software**: Composable, Modular, Repeatable.
+
+Core Features:
+- **Navigation:** Production navigation stack for any robot with lidar: SLAM, terrain analysis, collision
+ avoidance, route planning, exploration.
+- **Dashboard:** The DimOS command center gives developers the tooling to debug, visualize, compose, and
+ test dimensional applications in real-time. Control your robot via waypoint, agent query, keyboard,
+ VR, more.
+- **Modules:** Standalone components (equivalent to ROS nodes) that publish and subscribe to typed
+ In/Out streams that communicate over DimOS transports. The building blocks of Dimensional.
+- **Agents (experimental):** DimOS agents understand physical space, subscribe to sensor streams, and call
+ **physical** tools. Emergence appears when agents have physical agency.
+- **MCP (experimental):** Vibecode robots by giving your AI editor (Cursor, Claude Code) MCP access to run
+ physical commands (move forward 1 meter, jump, etc.).
+- **Manipulation (unreleased):** Classical (OMPL, IK, GraspGen), Agentive (TAMP), and VLA-native manipulation stack runs out-of-the-box on any DimOS supported arm embodiment.
+- **Transport/Middleware:** DimOS native Python transport supports LCM, DDS, and SHM, plus ROS 2.
+- **Robot integrations:** We integrate with the majority of hardware OEMs and are moving fast to cover
+ them all. Supported and/or immediate roadmap:
+
+ | Category | Platforms |
+ | --- | --- |
+ | Quadrupeds | Unitree Go2, Unitree B1, AGIBOT D1 Max/Pro, Dobot Rover |
+ | Drones | DJI Mavic 2, Holybro x500 |
+ | Humanoids | Unitree G1, Booster K1, AGIBOT X2, AGIBOT A2 |
+ | Arms | OpenARMs, xARM 6/7, AgileX Piper, HighTorque Pantera |
+
+# Getting Started
+
+## Installation
+
+Supported/tested matrix:
+
+| Platform | Status | Tested | Required System deps |
+| --- | --- | --- | --- |
+| Linux | supported | Ubuntu 22.04, 24.04 | See below |
+| macOS | experimental beta | not CI-tested | `brew install gnu-sed gcc portaudio git-lfs libjpeg-turbo python` |
+
+Note: macOS is usable but expect inconsistent/flaky behavior (rather than hard errors/crashes).
```sh
-# Install Nix
-curl --proto '=https' --tlsv1.2 -sSf -L https://install.determinate.systems/nix | sh -s -- install
-
-# clone the repository
-git clone --branch dev --single-branch https://github.com/dimensionalOS/dimos.git
-
-# setup the environment (follow the prompts after nix develop)
-cd dimos
-nix develop
-
-# You should be able to follow the instructions below as well for a more manual installation
-```
-
----
-## Python Installation
-Tested on Ubuntu 22.04/24.04
-
-```bash
-sudo apt install python3-venv
-
-# Clone the repository
-git clone --branch dev --single-branch https://github.com/dimensionalOS/dimos.git
-cd dimos
-
-# Create and activate virtual environment
-python3 -m venv venv
-source venv/bin/activate
-
-sudo apt install portaudio19-dev python3-pyaudio
-
-# Install LFS
-sudo apt install git-lfs
-git lfs install
-
-# Install torch and torchvision if not already installed
-# Example CUDA 11.7, Pytorch 2.0.1 (replace with your required pytorch version if different)
-pip install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+sudo apt-get update
+sudo apt-get install -y curl g++ portaudio19-dev git-lfs libturbojpeg python3-dev
+# install uv for python
+curl -LsSf https://astral.sh/uv/install.sh | sh && export PATH="$HOME/.local/bin:$PATH"
```
-#### Install dependencies
-```bash
-# CPU only (reccomended to attempt first)
-pip install -e '.[cpu,dev]'
-
-# CUDA install
-pip install -e '.[cuda,dev]'
-
-# Copy and configure environment variables
-cp default.env .env
-```
+Option 1: Install in a virtualenv
-#### Test the install
-```bash
-pytest -s dimos/
-```
+```sh
-#### Test Dimensional with a replay UnitreeGo2 stream (no robot required)
-```bash
+uv venv && . .venv/bin/activate
+uv pip install 'dimos[base,unitree]'
+# replay recorded data to test that the system is working
+# IMPORTANT: First replay run will show a black rerun window while 2.4 GB downloads from LFS
dimos --replay run unitree-go2
```
-#### Test Dimensional with a simulated UnitreeGo2 in MuJoCo (no robot required)
-```bash
-pip install -e '.[sim]'
-export DISPLAY=:1 # Or DISPLAY=:0 if getting GLFW/OpenGL X11 errors
-dimos --simulation run unitree-go2
-```
+Option 2: Run without installing
-#### Test Dimensional with a real UnitreeGo2 over WebRTC
-```bash
-export ROBOT_IP=192.168.X.XXX # Add the robot IP address
-dimos run unitree-go2
+```sh
+uvx --from 'dimos[base,unitree]' dimos --replay run unitree-go2
```
-#### Test Dimensional with a real UnitreeGo2 running Agents
-*OpenAI / Alibaba keys required*
-```bash
-export ROBOT_IP=192.168.X.XXX # Add the robot IP address
-dimos run unitree-go2-agentic
-```
----
+
-### Agent API keys
+### Test Installation
-Full functionality will require API keys for the following:
+#### Control a robot in a simulation (no robot required)
-Requirements:
-- OpenAI API key (required for all LLMAgents due to OpenAIEmbeddings)
-- Claude API key (required for ClaudeAgent)
-- Alibaba API key (required for Navigation skills)
-These keys can be added to your .env file or exported as environment variables.
-```
-export OPENAI_API_KEY=
-export CLAUDE_API_KEY=
-export ALIBABA_API_KEY=
+```sh
+export DISPLAY=:1 # Or DISPLAY=:0 if getting GLFW/OpenGL X11 errors
+# ignore the warp warnings
+dimos --viewer-backend rerun-web --simulation run unitree-go2
```
-### ROS2 Unitree Go2 SDK Installation
-
-#### System Requirements
-- Ubuntu 22.04
-- ROS2 Distros: Iron, Humble, Rolling
-
-See [Unitree Go2 ROS2 SDK](https://github.com/dimensionalOS/go2_ros2_sdk) for additional installation instructions.
-
-```bash
-mkdir -p ros2_ws
-cd ros2_ws
-git clone --recurse-submodules https://github.com/dimensionalOS/go2_ros2_sdk.git src
-sudo apt install ros-$ROS_DISTRO-image-tools
-sudo apt install ros-$ROS_DISTRO-vision-msgs
+#### Control a real robot (Unitree Go2 over WebRTC)
-sudo apt install python3-pip clang portaudio19-dev
-cd src
-pip install -r requirements.txt
-cd ..
-
-# Ensure clean python install before running
-source /opt/ros/$ROS_DISTRO/setup.bash
-rosdep install --from-paths src --ignore-src -r -y
-colcon build
+```sh
+export ROBOT_IP=
+dimos --viewer-backend rerun-web run unitree-go2
```
-### Run the test application
+After running DimOS, open http://localhost:7779 to control robot movement.
-#### ROS2 Terminal:
-```bash
-# Change path to your Go2 ROS2 SDK installation
-source /ros2_ws/install/setup.bash
-source /opt/ros/$ROS_DISTRO/setup.bash
+#### Dimensional Agents
-export ROBOT_IP="robot_ip" #for muliple robots, just split by ,
-export CONN_TYPE="webrtc"
-ros2 launch go2_robot_sdk robot.launch.py
+> \[!NOTE]
+>
+> **Experimental Beta: Potential unstoppable robot sentience**
+```sh
+export OPENAI_API_KEY=
+dimos --viewer-backend rerun-web run unitree-go2-agentic
```
-#### Python Terminal:
-```bash
-# Change path to your Go2 ROS2 SDK installation
-source /ros2_ws/install/setup.bash
-python tests/run.py
-```
+After running that, open a new terminal and run the following to start giving instructions to the agent.
+```sh
+# activate the venv in this new terminal
+source .venv/bin/activate
-#### DimOS Interface:
-```bash
-cd dimos/web/dimos_interface
-yarn install
-yarn dev # you may need to run sudo if previously built via Docker
+# then tell the agent "explore the room"
+# then tell it to go to something, ex: "go to the door"
+humancli
```
-### Project Structure (OUT OF DATE)
-
-```
-.
-├── dimos/
-│ ├── agents/ # Agent implementations
-│ │ └── memory/ # Memory systems for agents, including semantic memory
-│ ├── environment/ # Environment context and sensing
-│ ├── hardware/ # Hardware abstraction and interfaces
-│ ├── models/ # ML model definitions and implementations
-│ │ ├── Detic/ # Detic object detection model
-│ │ ├── depth/ # Depth estimation models
-│ │ ├── segmentation/ # Image segmentation models
-│ ├── perception/ # Computer vision and sensing
-│ │ ├── detection2d/ # 2D object detection
-│ │ └── segmentation/ # Image segmentation pipelines
-│ ├── robot/ # Robot control and hardware interface
-│ │ ├── global_planner/ # Path planning at global scale
-│ │ ├── local_planner/ # Local navigation planning
-│ │ └── unitree/ # Unitree Go2 specific implementations
-│ ├── simulation/ # Robot simulation environments
-│ │ ├── genesis/ # Genesis simulation integration
-│ │ └── isaac/ # NVIDIA Isaac Sim integration
-│ ├── skills/ # Task-specific robot capabilities
-│ │ └── rest/ # REST API based skills
-│ ├── stream/ # WebRTC and data streaming
-│ │ ├── audio/ # Audio streaming components
-│ │ └── video_providers/ # Video streaming components
-│ ├── types/ # Type definitions and interfaces
-│ ├── utils/ # Utility functions and helpers
-│ └── web/ # DimOS development interface and API
-│ ├── dimos_interface/ # DimOS web interface
-│ └── websocket_vis/ # Websocket visualizations
-├── tests/ # Test files
-│ ├── genesissim/ # Genesis simulator tests
-│ └── isaacsim/ # Isaac Sim tests
-└── docker/ # Docker configuration files
- ├── agent/ # Agent service containers
- ├── interface/ # Interface containers
- ├── simulation/ # Simulation environment containers
- └── unitree/ # Unitree robot specific containers
+# The Dimensional Library
+
+### Modules
+
+Modules are subsystems on a robot that operate autonomously and communicate with other subsystems using standardized messages. Below is a simple robot connection module that sends streams of continuous `cmd_vel` to the robot and publishes `color_image` to a simple `Listener` module.
+
+```py
+import threading, time, numpy as np
+from dimos.core import In, Module, Out, rpc
+from dimos.core.blueprints import autoconnect
+from dimos.msgs.geometry_msgs import Twist
+from dimos.msgs.sensor_msgs import Image
+from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ImageFormat
+
+class RobotConnection(Module):
+ cmd_vel: In[Twist]
+ color_image: Out[Image]
+
+ @rpc
+ def start(self):
+ threading.Thread(target=self._image_loop, daemon=True).start()
+
+ def _image_loop(self):
+ while True:
+ img = Image.from_numpy(
+ np.zeros((120, 160, 3), np.uint8),
+ format=ImageFormat.RGB,
+ frame_id="camera_optical",
+ )
+ self.color_image.publish(img)
+ time.sleep(0.2)
+
+class Listener(Module):
+ color_image: In[Image]
+
+ @rpc
+ def start(self):
+ self.color_image.subscribe(lambda img: print(f"image {img.width}x{img.height}"))
+
+if __name__ == "__main__":
+ autoconnect(
+ RobotConnection.blueprint(),
+ Listener.blueprint(),
+ ).build().loop()
```
-## Building
+### Blueprints
-### Simple DimOS Application (OUT OF DATE)
+Blueprints are how robots are constructed on Dimensional; instructions for how to construct and wire modules. You compose them with
+`autoconnect(...)`, which connects streams by `(name, type)` and returns a `ModuleBlueprintSet`.
-```python
-from dimos.robot.unitree.unitree_go2 import UnitreeGo2
-from dimos.robot.unitree.unitree_skills import MyUnitreeSkills
-from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl
-from dimos.agents_deprecated.agent import OpenAIAgent
+Blueprints can be composed, remapped, and have transports overridden if `autoconnect()` fails due to conflicting variable names or `In[]` and `Out[]` message types.
-# Initialize robot
-robot = UnitreeGo2(ip=robot_ip,
- ros_control=UnitreeROSControl(),
- skills=MyUnitreeSkills())
+A blueprint example that connects the image stream from a robot to an LLM Agent for reasoning and action execution.
+```py
+from dimos.core.blueprints import autoconnect
+from dimos.core.transport import LCMTransport
+from dimos.msgs.sensor_msgs import Image
+from dimos.robot.unitree.connection.go2 import go2_connection
+from dimos.agents.agent import llm_agent
-# Initialize agent
-agent = OpenAIAgent(
- dev_name="UnitreeExecutionAgent",
- input_video_stream=robot.get_ros_video_stream(),
- skills=robot.get_skills(),
- system_query="Jump when you see a human! Front flip when you see a dog!",
- model_name="gpt-4o"
- )
+blueprint = autoconnect(
+ go2_connection(),
+ llm_agent(),
+).transports({("color_image", Image): LCMTransport("/color_image", Image)})
-while True: # keep process running
- time.sleep(1)
+# Run the blueprint
+blueprint.build().loop()
```
+# Development
-### DimOS Application with Agent chaining (OUT OF DATE)
-
-Let's build a simple DimOS application with Agent chaining. We define a ```planner``` as a ```PlanningAgent``` that takes in user input to devise a complex multi-step plan. This plan is passed step-by-step to an ```executor``` agent that can queue ```AbstractRobotSkill``` commands to the ```ROSCommandQueue```.
-
-Our reactive Pub/Sub data streaming architecture allows for chaining of ```Agent_0 --> Agent_1 --> ... --> Agent_n``` via the ```input_query_stream``` parameter in each which takes an ```Observable``` input from the previous Agent in the chain.
-
-**Via this method you can chain together any number of Agents() to create complex dimensional applications.**
-
-```python
-
-web_interface = RobotWebInterface(port=5555)
-
-robot = UnitreeGo2(ip=robot_ip,
- ros_control=UnitreeROSControl(),
- skills=MyUnitreeSkills())
-
-# Initialize master planning agent
-planner = PlanningAgent(
- dev_name="UnitreePlanningAgent",
- input_query_stream=web_interface.query_stream, # Takes user input from dimOS interface
- skills=robot.get_skills(),
- model_name="gpt-4o",
- )
-
-# Initialize execution agent
-executor = OpenAIAgent(
- dev_name="UnitreeExecutionAgent",
- input_query_stream=planner.get_response_observable(), # Takes planner output as input
- skills=robot.get_skills(),
- model_name="gpt-4o",
- system_query="""
- You are a robot execution agent that can execute tasks on a virtual
- robot. ONLY OUTPUT THE SKILLS TO EXECUTE.
- """
- )
-
-while True: # keep process running
- time.sleep(1)
-```
-
-### Calling Action Primitives (OUT OF DATE)
-
-Call action primitives directly from ```Robot()``` for prototyping and testing.
-
-```python
-robot = UnitreeGo2(ip=robot_ip,)
-
-# Call a Unitree WebRTC action primitive
-robot.webrtc_req(api_id=1016) # "Hello" command
-
-# Call a ROS2 action primitive
-robot.move(distance=1.0, speed=0.5)
-```
-
-### Creating Custom Skills (non-unitree specific)
-
-#### Create basic custom skills by inheriting from ```AbstractRobotSkill``` and implementing the ```__call__``` method.
-
-```python
-class Move(AbstractRobotSkill):
- distance: float = Field(...,description="Distance to reverse in meters")
- def __init__(self, robot: Optional[Robot] = None, **data):
- super().__init__(robot=robot, **data)
- def __call__(self):
- super().__call__()
- return self._robot.move(distance=self.distance)
-```
-
-#### Chain together skills to create recursive skill trees
-
-```python
-class JumpAndFlip(AbstractRobotSkill):
- def __init__(self, robot: Optional[Robot] = None, **data):
- super().__init__(robot=robot, **data)
- def __call__(self):
- super().__call__()
- jump = Jump(robot=self._robot)
- flip = Flip(robot=self._robot)
- return (jump() and flip())
+```sh
+GIT_LFS_SKIP_SMUDGE=1 git clone -b dev https://github.com/dimensionalOS/dimos.git
+cd dimos
```
-### Integrating Skills with Agents: Single Skills and Skill Libraries
-
-DimOS agents, such as `OpenAIAgent`, can be endowed with capabilities through two primary mechanisms: by providing them with individual skill classes or with comprehensive `SkillLibrary` instances. This design offers flexibility in how robot functionalities are defined and managed within your agent-based applications.
-
-**Agent's `skills` Parameter**
-
-The `skills` parameter in an agent's constructor is key to this integration:
-
-1. **A Single Skill Class**: This approach is suitable for skills that are relatively self-contained or have straightforward initialization requirements.
- * You pass the skill *class itself* (e.g., `GreeterSkill`) directly to the agent's `skills` parameter.
- * The agent then takes on the responsibility of instantiating this skill when it's invoked. This typically involves the agent providing necessary context to the skill's constructor (`__init__`), such as a `Robot` instance (or any other private instance variable) if the skill requires it.
+Then pick one of two development paths:
-2. **A `SkillLibrary` Instance**: This is the preferred method for managing a collection of skills, especially when skills have dependencies, require specific configurations, or need to share parameters.
- * You first define your custom skill library by inheriting from `SkillLibrary`. Then, you create and configure an *instance* of this library (e.g., `my_lib = EntertainmentSkills(robot=robot_instance)`).
- * This pre-configured `SkillLibrary` instance is then passed to the agent's `skills` parameter. The library itself manages the lifecycle and provision of its contained skills.
-
-**Examples:**
-
-#### 1. Using a Single Skill Class with an Agent
-
-First, define your skill. For instance, a `GreeterSkill` that can deliver a configurable greeting:
-
-```python
-class GreeterSkill(AbstractSkill):
- """Greats the user with a friendly message.""" # Gives the agent better context for understanding (the more detailed the better).
-
- greeting: str = Field(..., description="The greating message to display.") # The field needed for the calling of the function. Your agent will also pull from the description here to gain better context.
-
- def __init__(self, greeting_message: Optional[str] = None, **data):
- super().__init__(**data)
- if greeting_message:
- self.greeting = greeting_message
- # Any additional skill-specific initialization can go here
-
- def __call__(self):
- super().__call__() # Call parent's method if it contains base logic
- # Implement the logic for the skill
- print(self.greeting)
- return f"Greeting delivered: '{self.greeting}'"
-```
-
-Next, register this skill *class* directly with your agent. The agent can then instantiate it, potentially with specific configurations if your agent or skill supports it (e.g., via default parameters or a more advanced setup).
-
-```python
-agent = OpenAIAgent(
- dev_name="GreetingBot",
- system_query="You are a polite bot. If a user asks for a greeting, use your GreeterSkill.",
- skills=GreeterSkill, # Pass the GreeterSkill CLASS
- # The agent will instantiate GreeterSkill.
- # If the skill had required __init__ args not provided by the agent automatically,
- # this direct class passing might be insufficient without further agent logic
- # or by passing a pre-configured instance (see SkillLibrary example).
- # For simple skills like GreeterSkill with defaults or optional args, this works well.
- model_name="gpt-4o"
-)
+Option A: Devcontainer
+```sh
+./bin/dev
```
-In this setup, when the `GreetingBot` agent decides to use the `GreeterSkill`, it will instantiate it. If the `GreeterSkill` were to be instantiated by the agent with a specific `greeting_message`, the agent's design would need to support passing such parameters during skill instantiation.
-
-#### 2. Using a `SkillLibrary` Instance with an Agent
-
-Define the SkillLibrary and any skills it will manage in its collection:
-```python
-class MovementSkillsLibrary(SkillLibrary):
- """A specialized skill library containing movement and navigation related skills."""
-
- def __init__(self, robot=None):
- super().__init__()
- self._robot = robot
- def initialize_skills(self, robot=None):
- """Initialize all movement skills with the robot instance."""
- if robot:
- self._robot = robot
-
- if not self._robot:
- raise ValueError("Robot instance is required to initialize skills")
-
- # Initialize with all movement-related skills
- self.add(Navigate(robot=self._robot))
- self.add(NavigateToGoal(robot=self._robot))
- self.add(FollowHuman(robot=self._robot))
- self.add(NavigateToObject(robot=self._robot))
- self.add(GetPose(robot=self._robot)) # Position tracking skill
+Option B: Editable install with uv
+```sh
+uv venv && . .venv/bin/activate
+uv pip install -e '.[base,dev]'
```
-Note the addision of initialized skills added to this collection above.
+For system deps, Nix setups, and testing, see `/docs/development/README.md`.
-Proceed to use this skill library in an Agent:
+### Monitoring & Debugging
-Finally, in your main application code:
-```python
-# 1. Create an instance of your custom skill library, configured with the robot
-my_movement_skills = MovementSkillsLibrary(robot=robot_instance)
+DimOS comes with a number of monitoring tools:
+- Run `lcmspy` to see how fast messages are being published on streams.
+- Run `skillspy` to see how skills are being called, how long they are running, which are active, etc.
+- Run `agentspy` to see the agent's status over time.
+- If you suspect there is a bug within DimOS itself, you can enable extreme logging by prefixing the dimos command with `DIMOS_LOG_LEVEL=DEBUG RERUN_SAVE=1 `. Ex: `DIMOS_LOG_LEVEL=DEBUG RERUN_SAVE=1 dimos --replay run unitree-go2`
-# 2. Pass this library INSTANCE to the agent
-performing_agent = OpenAIAgent(
- dev_name="ShowBot",
- system_query="You are a show robot. Use your skills as directed.",
- skills=my_movement_skills, # Pass the configured SkillLibrary INSTANCE
- model_name="gpt-4o"
-)
-```
-
-### Unitree Test Files
-- **`tests/run_go2_ros.py`**: Tests `UnitreeROSControl(ROSControl)` initialization in `UnitreeGo2(Robot)` via direct function calls `robot.move()` and `robot.webrtc_req()`
-- **`tests/simple_agent_test.py`**: Tests a simple zero-shot class `OpenAIAgent` example
-- **`tests/unitree/test_webrtc_queue.py`**: Tests `ROSCommandQueue` via a 20 back-to-back WebRTC requests to the robot
-- **`tests/test_planning_agent_web_interface.py`**: Tests a simple two-stage `PlanningAgent` chained to an `ExecutionAgent` with backend FastAPI interface.
-- **`tests/test_unitree_agent_queries_fastapi.py`**: Tests a zero-shot `ExecutionAgent` with backend FastAPI interface.
-## Documentation
+# Documentation
-For detailed documentation, please visit our [documentation site](#) (Coming Soon).
+Concepts:
+- [Modules](/docs/concepts/modules.md): The building blocks of DimOS, modules run in parallel and are singleton python classes.
+- [Streams](/docs/api/sensor_streams/index.md): How modules communicate, a Pub / Sub system.
+- [Blueprints](/dimos/core/README_BLUEPRINTS.md): a way to group modules together and define their connections to each other.
+- [RPC](/dimos/core/README_BLUEPRINTS.md#calling-the-methods-of-other-modules): how one module can call a method on another module (arguments get serialized to JSON-like binary data).
+- [Skills](/dimos/core/README_BLUEPRINTS.md#defining-skills): An RPC function, except it can be called by an AI agent (a tool for an AI).
## Contributing
We welcome contributions! See our [Bounty List](https://docs.google.com/spreadsheets/d/1tzYTPvhO7Lou21cU6avSWTQOhACl5H8trSvhtYtsk8U/edit?usp=sharing) for open requests for contributions. If you would like to suggest a feature or sponsor a bounty, open an issue.
-
-## License
-
-This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
-
-## Acknowledgments
-
-Huge thanks to!
-- The Roboverse Community and their unitree-specific help. Check out their [Discord](https://discord.gg/HEXNMCNhEh).
-- @abizovnuralem for his work on the [Unitree Go2 ROS2 SDK](https://github.com/abizovnuralem/go2_ros2_sdk) we integrate with for DimOS.
-- @legion1581 for his work on the [Unitree Go2 WebRTC Connect](https://github.com/legion1581/go2_webrtc_connect) from which we've pulled the ```Go2WebRTCConnection``` class and other types for seamless WebRTC-only integration with DimOS.
-- @tfoldi for the webrtc_req integration via Unitree Go2 ROS2 SDK, which allows for seamless usage of Unitree WebRTC control primitives with DimOS.
-
-## Contact
-
-- GitHub Issues: For bug reports and feature requests
-- Email: [build@dimensionalOS.com](mailto:build@dimensionalOS.com)
-
-## Known Issues
-- Agent() failure to execute Nav2 action primitives (move, reverse, spinLeft, spinRight) is almost always due to the internal ROS2 collision avoidance, which will sometimes incorrectly display obstacles or be overly sensitive. Look for ```[behavior_server]: Collision Ahead - Exiting DriveOnHeading``` in the ROS logs. Reccomend restarting ROS2 or moving robot from objects to resolve.
-- ```docker-compose up --build``` does not fully initialize the ROS2 environment due to ```std::bad_alloc``` errors. This will occur during continuous docker development if the ```docker-compose down``` is not run consistently before rebuilding and/or you are on a machine with less RAM, as ROS is very memory intensive. Reccomend running to clear your docker cache/images/containers with ```docker system prune``` and rebuild.
diff --git a/bin/hooks/filter_commit_message.py b/bin/hooks/filter_commit_message.py
index cd92b196af..d22eaf9484 100644
--- a/bin/hooks/filter_commit_message.py
+++ b/bin/hooks/filter_commit_message.py
@@ -28,10 +28,16 @@ def main() -> int:
lines = commit_msg_file.read_text().splitlines(keepends=True)
- # Find the first line containing "Generated with" and truncate there
+ # Patterns that trigger truncation (everything from this line onwards is removed)
+ truncate_patterns = [
+ "Generated with",
+ "Co-Authored-By",
+ ]
+
+ # Find the first line containing any truncate pattern and truncate there
filtered_lines = []
for line in lines:
- if "Generated with" in line:
+ if any(pattern in line for pattern in truncate_patterns):
break
filtered_lines.append(line)
diff --git a/bin/pytest-fast b/bin/pytest-fast
new file mode 100755
index 0000000000..cb25f93288
--- /dev/null
+++ b/bin/pytest-fast
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+. .venv/bin/activate
+exec pytest "$@" dimos
diff --git a/bin/pytest-mujoco b/bin/pytest-mujoco
new file mode 100755
index 0000000000..07e7ed90bc
--- /dev/null
+++ b/bin/pytest-mujoco
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+. .venv/bin/activate
+exec pytest "$@" -m mujoco dimos
diff --git a/bin/pytest-slow b/bin/pytest-slow
new file mode 100755
index 0000000000..85643d4413
--- /dev/null
+++ b/bin/pytest-slow
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+. .venv/bin/activate
+exec pytest "$@" -m 'not (tool or module or neverending or mujoco)' dimos
diff --git a/data/.lfs/command_center.html.tar.gz b/data/.lfs/command_center.html.tar.gz
new file mode 100644
index 0000000000..f3089d7b87
--- /dev/null
+++ b/data/.lfs/command_center.html.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2daabec1baf19c95cb50eadef8d3521289044adde03c86909351894cb90c9843
+size 137595
diff --git a/data/.lfs/models_edgetam.tar.gz b/data/.lfs/models_edgetam.tar.gz
new file mode 100644
index 0000000000..64baa5d139
--- /dev/null
+++ b/data/.lfs/models_edgetam.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd452096f91415ce7ca90548a06a87354ccdb19a66925c0242413c80b08f5c57
+size 51988780
diff --git a/data/.lfs/models_yoloe.tar.gz b/data/.lfs/models_yoloe.tar.gz
new file mode 100644
index 0000000000..a0870d71d2
--- /dev/null
+++ b/data/.lfs/models_yoloe.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a78e39477667b25c9454f846cd66dc044dd05981b2f7ebb0d331ef3626de9bc
+size 184892540
diff --git a/data/.lfs/person.tar.gz b/data/.lfs/person.tar.gz
new file mode 100644
index 0000000000..1f32d0db58
--- /dev/null
+++ b/data/.lfs/person.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:332c3196c6436e7d4c2b7e3314b4a4055865ef358b2e9cf3c8ddd7e173f39b93
+size 2535758
diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py
new file mode 100644
index 0000000000..0d4420632c
--- /dev/null
+++ b/dimos/agents/skills/person_follow.py
@@ -0,0 +1,248 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from threading import Event, RLock
+import time
+from typing import TYPE_CHECKING
+
+import numpy as np
+from reactivex.disposable import Disposable
+
+from dimos.core.core import rpc
+from dimos.core.global_config import GlobalConfig
+from dimos.core.skill_module import SkillModule
+from dimos.core.stream import In, Out
+from dimos.models.qwen.video_query import BBox
+from dimos.models.segmentation.edge_tam import EdgeTAMProcessor
+from dimos.models.vl.qwen import QwenVlModel
+from dimos.msgs.geometry_msgs import Twist
+from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
+from dimos.navigation.visual.query import get_object_bbox_from_image
+from dimos.navigation.visual_servoing.detection_navigation import DetectionNavigation
+from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D
+from dimos.protocol.skill.skill import skill
+from dimos.utils.logging_config import setup_logger
+
+if TYPE_CHECKING:
+ from dimos.models.vl.base import VlModel
+
+logger = setup_logger()
+
+
+class PersonFollowSkillContainer(SkillModule):
+ """Skill container for following a person.
+
+ This skill uses:
+ - A VL model (QwenVlModel) to initially detect a person from a text description.
+ - EdgeTAM for continuous tracking across frames.
+ - Visual servoing OR 3D navigation to control robot movement towards the person.
+ - Does not do obstacle avoidance; assumes a clear path.
+ """
+
+ color_image: In[Image]
+ global_map: In[PointCloud2]
+ cmd_vel: Out[Twist]
+
+ _frequency: float = 20.0 # Hz - control loop frequency
+ _max_lost_frames: int = 15 # number of frames to wait before declaring person lost
+
+ def __init__(
+ self,
+ camera_info: CameraInfo,
+ global_config: GlobalConfig,
+ use_3d_navigation: bool = False,
+ ) -> None:
+ super().__init__()
+ self._global_config: GlobalConfig = global_config
+ self._use_3d_navigation: bool = use_3d_navigation
+ self._latest_image: Image | None = None
+ self._latest_pointcloud: PointCloud2 | None = None
+ self._vl_model: VlModel = QwenVlModel()
+ self._tracker: EdgeTAMProcessor | None = None
+ self._should_stop: Event = Event()
+ self._lock = RLock()
+
+ # Use MuJoCo camera intrinsics in simulation mode
+ if self._global_config.simulation:
+ from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection
+
+ camera_info = MujocoConnection.camera_info_static
+
+ self._camera_info = camera_info
+ self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation)
+ self._detection_navigation = DetectionNavigation(self.tf, camera_info)
+
+ @rpc
+ def start(self) -> None:
+ super().start()
+ self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image)))
+ if self._use_3d_navigation:
+ self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud)))
+
+ @rpc
+ def stop(self) -> None:
+ self._stop_following()
+
+ with self._lock:
+ if self._tracker is not None:
+ self._tracker.stop()
+ self._tracker = None
+
+ self._vl_model.stop()
+ super().stop()
+
+ @skill()
+ def follow_person(self, query: str) -> str:
+ """Follow a person matching the given description using visual servoing.
+
+ The robot will continuously track and follow the person, while keeping
+ them centered in the camera view.
+
+ Args:
+ query: Description of the person to follow (e.g., "man with blue shirt")
+
+ Returns:
+ Status message indicating the result of the following action.
+
+ Example:
+ follow_person("man with blue shirt")
+ follow_person("person in the doorway")
+ """
+
+ self._stop_following()
+
+ self._should_stop.clear()
+
+ with self._lock:
+ latest_image = self._latest_image
+
+ if latest_image is None:
+ return "No image available to detect person."
+
+ initial_bbox = get_object_bbox_from_image(
+ self._vl_model,
+ latest_image,
+ query,
+ )
+
+ if initial_bbox is None:
+ return f"Could not find '{query}' in the current view."
+
+ return self._follow_loop(query, initial_bbox)
+
+ @skill()
+ def stop_following(self) -> str:
+ """Stop following the current person.
+
+ Returns:
+ Confirmation message.
+ """
+ self._stop_following()
+
+ self.cmd_vel.publish(Twist.zero())
+
+ return "Stopped following."
+
+ def _on_color_image(self, image: Image) -> None:
+ with self._lock:
+ self._latest_image = image
+
+ def _on_pointcloud(self, pointcloud: PointCloud2) -> None:
+ with self._lock:
+ self._latest_pointcloud = pointcloud
+
+ def _follow_loop(self, query: str, initial_bbox: BBox) -> str:
+ x1, y1, x2, y2 = initial_bbox
+ box = np.array([x1, y1, x2, y2], dtype=np.float32)
+
+ with self._lock:
+ if self._tracker is None:
+ self._tracker = EdgeTAMProcessor()
+ tracker = self._tracker
+ latest_image = self._latest_image
+ if latest_image is None:
+ return "No image available to start tracking."
+
+ initial_detections = tracker.init_track(
+ image=latest_image,
+ box=box,
+ obj_id=1,
+ )
+
+ if len(initial_detections) == 0:
+ self.cmd_vel.publish(Twist.zero())
+ return f"EdgeTAM failed to segment '{query}'."
+
+ logger.info(f"EdgeTAM initialized with {len(initial_detections)} detections")
+
+ lost_count = 0
+ period = 1.0 / self._frequency
+ next_time = time.monotonic()
+
+ while not self._should_stop.is_set():
+ next_time += period
+
+ with self._lock:
+ latest_image = self._latest_image
+ assert latest_image is not None
+
+ detections = tracker.process_image(latest_image)
+
+ if len(detections) == 0:
+ self.cmd_vel.publish(Twist.zero())
+
+ lost_count += 1
+ if lost_count > self._max_lost_frames:
+ self.cmd_vel.publish(Twist.zero())
+ return f"Lost track of '{query}'. Stopping."
+ else:
+ lost_count = 0
+ best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume())
+
+ if self._use_3d_navigation:
+ with self._lock:
+ pointcloud = self._latest_pointcloud
+ if pointcloud is None:
+ self.cmd_vel.publish(Twist.zero())
+ return "No pointcloud available for 3D navigation. Stopping."
+ twist = self._detection_navigation.compute_twist_for_detection_3d(
+ pointcloud,
+ best_detection,
+ latest_image,
+ )
+ if twist is None:
+ self.cmd_vel.publish(Twist.zero())
+ return f"3D navigation failed for '{query}'. Stopping."
+ else:
+ twist = self._visual_servo.compute_twist(
+ best_detection.bbox,
+ latest_image.width,
+ )
+ self.cmd_vel.publish(twist)
+
+ now = time.monotonic()
+ sleep_duration = next_time - now
+ if sleep_duration > 0:
+ time.sleep(sleep_duration)
+
+ self.cmd_vel.publish(Twist.zero())
+ return "Stopped following as requested."
+
+ def _stop_following(self) -> None:
+ self._should_stop.set()
+
+
+person_follow_skill = PersonFollowSkillContainer.blueprint
+
+__all__ = ["PersonFollowSkillContainer", "person_follow_skill"]
diff --git a/dimos/agents/skills/test_navigation.py b/dimos/agents/skills/test_navigation.py
index 588b55a602..67e0429cb5 100644
--- a/dimos/agents/skills/test_navigation.py
+++ b/dimos/agents/skills/test_navigation.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
from dimos.msgs.geometry_msgs import PoseStamped, Vector3
from dimos.utils.transform_utils import euler_to_quaternion
-# @pytest.mark.skip
def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker) -> None:
cancel_goal_mock = mocker.Mock()
stop_exploration_mock = mocker.Mock()
@@ -35,6 +35,7 @@ def test_stop_movement(create_navigation_agent, navigation_skill_container, mock
stop_exploration_mock.assert_called_once_with()
+@pytest.mark.integration
def test_take_a_look_around(create_navigation_agent, navigation_skill_container, mocker) -> None:
explore_mock = mocker.Mock()
is_exploration_active_mock = mocker.Mock()
diff --git a/dimos/agents/test_agent_fake.py b/dimos/agents/test_agent_fake.py
index 367985a356..e544765758 100644
--- a/dimos/agents/test_agent_fake.py
+++ b/dimos/agents/test_agent_fake.py
@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
+
+@pytest.mark.integration
def test_what_is_your_name(create_potato_agent) -> None:
agent = create_potato_agent(fixture="test_what_is_your_name.json")
response = agent.query("hi there, please tell me what's your name?")
assert "Mr. Potato" in response
+@pytest.mark.integration
def test_how_much_is_124181112_plus_124124(create_potato_agent) -> None:
agent = create_potato_agent(fixture="test_how_much_is_124181112_plus_124124.json")
@@ -29,6 +33,7 @@ def test_how_much_is_124181112_plus_124124(create_potato_agent) -> None:
assert "999000000" in response.replace(",", "")
+@pytest.mark.integration
def test_what_do_you_see_in_this_picture(create_potato_agent) -> None:
agent = create_potato_agent(fixture="test_what_do_you_see_in_this_picture.json")
diff --git a/dimos/agents/test_mock_agent.py b/dimos/agents/test_mock_agent.py
index c711e23143..4f449e973a 100644
--- a/dimos/agents/test_mock_agent.py
+++ b/dimos/agents/test_mock_agent.py
@@ -24,12 +24,12 @@
from dimos.agents.testing import MockModel
from dimos.core import LCMTransport, start
from dimos.msgs.geometry_msgs import PoseStamped, Vector3
-from dimos.msgs.sensor_msgs import Image
+from dimos.msgs.sensor_msgs import Image, PointCloud2
from dimos.protocol.skill.test_coordinator import SkillContainerTest
from dimos.robot.unitree.connection.go2 import GO2Connection
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+@pytest.mark.integration
def test_tool_call() -> None:
"""Test agent initialization and tool call execution."""
# Create a fake model that will respond with tool calls
@@ -74,6 +74,7 @@ def test_tool_call() -> None:
agent.stop()
+@pytest.mark.integration
def test_image_tool_call() -> None:
"""Test agent with image tool call execution."""
dimos = start(2)
@@ -158,7 +159,7 @@ def test_tool_call_implicit_detections() -> None:
)
robot_connection = dimos.deploy(GO2Connection, connection_type="fake")
- robot_connection.lidar.transport = LCMTransport("/lidar", LidarMessage)
+ robot_connection.lidar.transport = LCMTransport("/lidar", PointCloud2)
robot_connection.odom.transport = LCMTransport("/odom", PoseStamped)
robot_connection.video.transport = LCMTransport("/image", Image)
robot_connection.cmd_vel.transport = LCMTransport("/cmd_vel", Vector3)
diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py
index 0757a59d22..0b99fe4d1c 100644
--- a/dimos/agents/vlm_agent.py
+++ b/dimos/agents/vlm_agent.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from typing import Any
+
+from langchain_core.messages import AIMessage, HumanMessage
from dimos.agents.llm_init import build_llm, build_system_message
from dimos.agents.spec import AgentSpec, AnyMessage
@@ -31,7 +33,7 @@ class VLMAgent(AgentSpec):
query_stream: In[HumanMessage]
answer_stream: Out[AIMessage]
- def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._llm = build_llm(self.config)
self._latest_image: Image | None = None
@@ -71,18 +73,23 @@ def _extract_text(self, msg: HumanMessage) -> str:
return str(part.get("text", ""))
return str(content)
- def _invoke(self, msg: HumanMessage) -> AIMessage:
+ def _invoke(self, msg: HumanMessage, **kwargs: Any) -> AIMessage:
messages = [self._system_message, msg]
- response = self._llm.invoke(messages)
+ response = self._llm.invoke(messages, **kwargs)
self.append_history([msg, response]) # type: ignore[arg-type]
return response # type: ignore[return-value]
- def _invoke_image(self, image: Image, query: str) -> AIMessage:
+ def _invoke_image(
+ self, image: Image, query: str, response_format: dict[str, Any] | None = None
+ ) -> AIMessage:
content = [{"type": "text", "text": query}, *image.agent_encode()]
- return self._invoke(HumanMessage(content=content))
+ kwargs: dict[str, Any] = {}
+ if response_format:
+ kwargs["response_format"] = response_format
+ return self._invoke(HumanMessage(content=content), **kwargs)
@rpc
- def clear_history(self): # type: ignore[no-untyped-def]
+ def clear_history(self) -> None:
self._history.clear()
def append_history(self, *msgs: list[AIMessage | HumanMessage]) -> None:
@@ -95,9 +102,7 @@ def history(self) -> list[AnyMessage]:
return [self._system_message, *self._history]
@rpc
- def register_skills( # type: ignore[no-untyped-def]
- self, container, run_implicit_name: str | None = None
- ) -> None:
+ def register_skills(self, container: Any, run_implicit_name: str | None = None) -> None:
logger.warning(
"VLMAgent does not manage skills; register_skills is a no-op",
container=str(container),
@@ -105,14 +110,18 @@ def register_skills( # type: ignore[no-untyped-def]
)
@rpc
- def query(self, query: str): # type: ignore[no-untyped-def]
+ def query(self, query: str) -> str:
response = self._invoke(HumanMessage(query))
- return response.content
+ content = response.content
+ return content if isinstance(content, str) else str(content)
@rpc
- def query_image(self, image: Image, query: str): # type: ignore[no-untyped-def]
- response = self._invoke_image(image, query)
- return response.content
+ def query_image(
+ self, image: Image, query: str, response_format: dict[str, Any] | None = None
+ ) -> str:
+ response = self._invoke_image(image, query, response_format=response_format)
+ content = response.content
+ return content if isinstance(content, str) else str(content)
vlm_agent = VLMAgent.blueprint
diff --git a/dimos/agents_deprecated/memory/spatial_vector_db.py b/dimos/agents_deprecated/memory/spatial_vector_db.py
index 0c8774cd95..c482076325 100644
--- a/dimos/agents_deprecated/memory/spatial_vector_db.py
+++ b/dimos/agents_deprecated/memory/spatial_vector_db.py
@@ -225,8 +225,8 @@ def _process_query_results(self, results) -> list[dict]: # type: ignore[no-unty
)
# Get the image from visual memory
- image = self.visual_memory.get(lookup_id)
- result["image"] = image
+ #image = self.visual_memory.get(lookup_id)
+ #result["image"] = image
processed_results.append(result)
diff --git a/dimos/assets/foxglove_dashboards/Overwatch.json b/dimos/assets/foxglove_dashboards/Overwatch.json
new file mode 100644
index 0000000000..5a26abbc90
--- /dev/null
+++ b/dimos/assets/foxglove_dashboards/Overwatch.json
@@ -0,0 +1,522 @@
+{
+ "_watermark": {
+ "creator": "DIMOS Navigation",
+ "author": "bona",
+ "organization": "Dimensional",
+ "version": "1.0",
+ "timestamp": "2026-01-17",
+ "signature": "dimos-nav-overwatch-layout"
+ },
+ "configById": {
+ "3D!main": {
+ "cameraState": {
+ "perspective": true,
+ "distance": 11.376001845526945,
+ "phi": 60.000000000004256,
+ "thetaOffset": -154.65065502183666,
+ "targetOffset": [
+ -1.3664627921863377,
+ -0.4507491962036497,
+ 2.2548288625522925e-16
+ ],
+ "target": [
+ 0,
+ 0,
+ 0
+ ],
+ "targetOrientation": [
+ 0,
+ 0,
+ 0,
+ 1
+ ],
+ "fovy": 45,
+ "near": 0.01,
+ "far": 5000
+ },
+ "followMode": "follow-none",
+ "followTf": "map",
+ "scene": {
+ "backgroundColor": "#000000",
+ "enableStats": false,
+ "syncCamera": false
+ },
+ "transforms": {
+ "frame:vehicle": {
+ "visible": true
+ },
+ "frame:map": {
+ "visible": true
+ }
+ },
+ "topics": {
+ "/registered_scan": {
+ "visible": true,
+ "colorField": "z",
+ "colorMode": "colormap",
+ "flatColor": "#ffffff",
+ "pointSize": 2,
+ "decayTime": 1
+ },
+ "/overall_map": {
+ "visible": true,
+ "colorField": "intensity",
+ "colorMode": "flat",
+ "flatColor": "#ffffff",
+ "pointSize": 2,
+ "decayTime": 0
+ },
+ "/free_paths": {
+ "visible": true,
+ "colorField": "intensity",
+ "colorMode": "colormap",
+ "colorMap": "turbo",
+ "pointSize": 2
+ },
+ "/path": {
+ "visible": true,
+ "type": "line",
+ "lineWidth": 0.05,
+ "color": "#19ff00"
+ },
+ "/way_point": {
+ "visible": true,
+ "color": "#cc29cc"
+ },
+ "/navigation_boundary": {
+ "visible": true,
+ "color": "#00ff00",
+ "lineWidth": 0.2
+ },
+ "/goal_pose": {
+ "visible": true,
+ "color": "#ff1900",
+ "type": "arrow"
+ },
+ "/terrain_map": {
+ "visible": false,
+ "colorField": "intensity",
+ "colorMode": "colormap",
+ "colorMap": "rainbow",
+ "pointSize": 4
+ },
+ "/terrain_map_ext": {
+ "visible": false,
+ "colorField": "intensity",
+ "colorMode": "colormap",
+ "colorMap": "rainbow",
+ "pointSize": 4
+ },
+ "/sensor_scan": {
+ "visible": false,
+ "colorField": "intensity",
+ "colorMode": "flat",
+ "flatColor": "#ffffff",
+ "pointSize": 2
+ },
+ "/added_obstacles": {
+ "visible": false,
+ "colorField": "intensity",
+ "colorMode": "flat",
+ "flatColor": "#ff1900",
+ "pointSize": 3
+ },
+ "/explored_areas": {
+ "visible": false,
+ "colorField": "intensity",
+ "colorMode": "flat",
+ "flatColor": "#00aaff",
+ "pointSize": 2
+ },
+ "/trajectory": {
+ "visible": false,
+ "colorField": "intensity",
+ "colorMode": "colormap",
+ "colorMap": "rainbow",
+ "pointSize": 7
+ },
+ "/viz_graph_topic": {
+ "visible": true,
+ "namespaces": {
+ "angle_direct": {
+ "visible": false
+ },
+ "boundary_edge": {
+ "visible": false
+ },
+ "boundary_vertex": {
+ "visible": false
+ },
+ "freespace_vertex": {
+ "visible": false
+ },
+ "freespace_vgraph": {
+ "visible": true
+ },
+ "frontier_vertex": {
+ "visible": false
+ },
+ "global_vertex": {
+ "visible": false
+ },
+ "global_vgraph": {
+ "visible": false
+ },
+ "localrange_vertex": {
+ "visible": false
+ },
+ "odom_edge": {
+ "visible": false
+ },
+ "polygon_edge": {
+ "visible": true
+ },
+ "to_goal_edge": {
+ "visible": false
+ },
+ "trajectory_edge": {
+ "visible": false
+ },
+ "trajectory_vertex": {
+ "visible": false
+ },
+ "updating_vertex": {
+ "visible": false
+ },
+ "vertex_angle": {
+ "visible": false
+ },
+ "vertices_matches": {
+ "visible": false
+ },
+ "visibility_edge": {
+ "visible": false
+ }
+ }
+ }
+ },
+ "layers": {
+ "grid": {
+ "layerId": "foxglove.Grid",
+ "visible": true,
+ "frameLocked": true,
+ "label": "Grid",
+ "position": [
+ 0,
+ 0,
+ 0
+ ],
+ "rotation": [
+ 0,
+ 0,
+ 0
+ ],
+ "color": "#248f24",
+ "size": 100,
+ "divisions": 100,
+ "lineWidth": 1,
+ "frameId": "map"
+ }
+ },
+ "publish": {
+ "type": "pose",
+ "poseTopic": "/goal_pose",
+ "pointTopic": "/way_point",
+ "poseEstimateTopic": "/initialpose",
+ "poseEstimateXDeviation": 0.5,
+ "poseEstimateYDeviation": 0.5,
+ "poseEstimateThetaDeviation": 0.26179939,
+ "frameId": "map"
+ },
+ "imageMode": {},
+ "fixedFrame": "map"
+ },
+ "Image!camera": {
+ "cameraTopic": "/camera/image",
+ "enabledMarkerTopics": [],
+ "synchronize": false,
+ "transformMarkers": true,
+ "smooth": false,
+ "flipHorizontal": false,
+ "flipVertical": false,
+ "minValue": 0,
+ "maxValue": 1,
+ "rotation": 0,
+ "cameraState": {
+ "distance": 20,
+ "perspective": true,
+ "phi": 60,
+ "target": [
+ 0,
+ 0,
+ 0
+ ],
+ "targetOffset": [
+ 0,
+ 0,
+ 0
+ ],
+ "targetOrientation": [
+ 0,
+ 0,
+ 0,
+ 1
+ ],
+ "thetaOffset": 45,
+ "fovy": 45,
+ "near": 0.5,
+ "far": 5000
+ },
+ "followMode": "follow-pose",
+ "scene": {},
+ "transforms": {},
+ "topics": {},
+ "layers": {},
+ "publish": {
+ "type": "point",
+ "poseTopic": "/move_base_simple/goal",
+ "pointTopic": "/clicked_point",
+ "poseEstimateTopic": "/initialpose",
+ "poseEstimateXDeviation": 0.5,
+ "poseEstimateYDeviation": 0.5,
+ "poseEstimateThetaDeviation": 0.26179939
+ },
+ "imageMode": {}
+ },
+ "Image!semantic": {
+ "cameraTopic": "/camera/semantic_image",
+ "enabledMarkerTopics": [],
+ "synchronize": false,
+ "transformMarkers": true,
+ "smooth": false,
+ "flipHorizontal": false,
+ "flipVertical": false,
+ "minValue": 0,
+ "maxValue": 1,
+ "rotation": 0,
+ "cameraState": {
+ "distance": 20,
+ "perspective": true,
+ "phi": 60,
+ "target": [
+ 0,
+ 0,
+ 0
+ ],
+ "targetOffset": [
+ 0,
+ 0,
+ 0
+ ],
+ "targetOrientation": [
+ 0,
+ 0,
+ 0,
+ 1
+ ],
+ "thetaOffset": 45,
+ "fovy": 45,
+ "near": 0.5,
+ "far": 5000
+ },
+ "followMode": "follow-pose",
+ "scene": {},
+ "transforms": {},
+ "topics": {},
+ "layers": {},
+ "publish": {
+ "type": "point",
+ "poseTopic": "/move_base_simple/goal",
+ "pointTopic": "/clicked_point",
+ "poseEstimateTopic": "/initialpose",
+ "poseEstimateXDeviation": 0.5,
+ "poseEstimateYDeviation": 0.5,
+ "poseEstimateThetaDeviation": 0.26179939
+ },
+ "imageMode": {}
+ },
+ "Teleop!teleop": {
+ "topic": "/foxglove_teleop",
+ "publishRate": 10,
+ "upButton": {
+ "field": "linear.x",
+ "value": 0.5
+ },
+ "downButton": {
+ "field": "linear.x",
+ "value": -0.5
+ },
+ "leftButton": {
+ "field": "angular.z",
+ "value": 0.5
+ },
+ "rightButton": {
+ "field": "angular.z",
+ "value": -0.5
+ }
+ },
+ "RawMessages!odom": {
+ "diffEnabled": false,
+ "diffMethod": "custom",
+ "diffTopicPath": "",
+ "showFullMessageForDiff": false,
+ "topicPath": "/state_estimation.pose.pose"
+ },
+ "RawMessages!cmdvel": {
+ "diffEnabled": false,
+ "diffMethod": "custom",
+ "diffTopicPath": "",
+ "showFullMessageForDiff": false,
+ "topicPath": "/cmd_vel.twist",
+ "fontSize": 12
+ },
+ "Plot!speed": {
+ "paths": [
+ {
+ "value": "/cmd_vel.twist.linear.x",
+ "enabled": true,
+ "timestampMethod": "receiveTime",
+ "label": "Linear X"
+ },
+ {
+ "value": "/cmd_vel.twist.linear.y",
+ "enabled": true,
+ "timestampMethod": "receiveTime",
+ "label": "Linear Y"
+ },
+ {
+ "value": "/cmd_vel.twist.angular.z",
+ "enabled": true,
+ "timestampMethod": "receiveTime",
+ "label": "Angular Z"
+ }
+ ],
+ "showXAxisLabels": true,
+ "showYAxisLabels": true,
+ "showLegend": true,
+ "legendDisplay": "floating",
+ "showPlotValuesInLegend": true,
+ "isSynced": true,
+ "xAxisVal": "timestamp",
+ "sidebarDimension": 240,
+ "minYValue": -2,
+ "maxYValue": 2,
+ "followingViewWidth": 30
+ },
+ "Indicator!goalreached": {
+ "path": "/goal_reached.data",
+ "style": "background",
+ "fallbackColor": "#a0a0a0",
+ "fallbackLabel": "No Data",
+ "rules": [
+ {
+ "operator": "=",
+ "rawValue": "true",
+ "color": "#68e24a",
+ "label": "Goal Reached"
+ },
+ {
+ "operator": "=",
+ "rawValue": "false",
+ "color": "#f5f5f5",
+ "label": "Navigating"
+ }
+ ],
+ "fontSize": 36
+ },
+ "Indicator!stop": {
+ "path": "/stop.data",
+ "style": "background",
+ "fallbackColor": "#a0a0a0",
+ "fallbackLabel": "No Data",
+ "rules": [
+ {
+ "operator": "=",
+ "rawValue": "0",
+ "color": "#68e24a",
+ "label": "OK"
+ },
+ {
+ "operator": "=",
+ "rawValue": "1",
+ "color": "#f5ba42",
+ "label": "Speed Stop"
+ },
+ {
+ "operator": "=",
+ "rawValue": "2",
+ "color": "#eb4034",
+ "label": "Full Stop"
+ }
+ ],
+ "fontSize": 36
+ },
+ "Indicator!autonomy": {
+ "path": "/joy.axes[2]",
+ "style": "background",
+ "fallbackColor": "#a0a0a0",
+ "fallbackLabel": "No Joystick",
+ "rules": [
+ {
+ "operator": "<",
+ "rawValue": "-0.1",
+ "color": "#68e24a",
+ "label": "Autonomy ON"
+ },
+ {
+ "operator": ">=",
+ "rawValue": "-0.1",
+ "color": "#eb4034",
+ "label": "Autonomy OFF"
+ }
+ ],
+ "fontSize": 36
+ }
+ },
+ "globalVariables": {},
+ "userNodes": {},
+ "playbackConfig": {
+ "speed": 1
+ },
+ "layout": {
+ "first": {
+ "first": "3D!main",
+ "second": {
+ "first": "Image!camera",
+ "second": "Image!semantic",
+ "direction": "column",
+ "splitPercentage": 50
+ },
+ "direction": "row",
+ "splitPercentage": 70
+ },
+ "second": {
+ "first": {
+ "first": "Teleop!teleop",
+ "second": {
+ "first": "Indicator!autonomy",
+ "second": {
+ "first": "Indicator!goalreached",
+ "second": "Indicator!stop",
+ "direction": "row",
+ "splitPercentage": 50
+ },
+ "direction": "row",
+ "splitPercentage": 33
+ },
+ "direction": "row",
+ "splitPercentage": 40
+ },
+ "second": {
+ "first": "RawMessages!cmdvel",
+ "second": "Plot!speed",
+ "direction": "row",
+ "splitPercentage": 30
+ },
+ "direction": "column",
+ "splitPercentage": 40
+ },
+ "direction": "column",
+ "splitPercentage": 75
+ }
+}
diff --git a/dimos/control/README.md b/dimos/control/README.md
new file mode 100644
index 0000000000..58490321fa
--- /dev/null
+++ b/dimos/control/README.md
@@ -0,0 +1,195 @@
+# Control Orchestrator
+
+Centralized control system for multi-arm robots with per-joint arbitration.
+
+## Architecture
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│ ControlOrchestrator │
+│ │
+│ ┌──────────────────────────────────────────────────────┐ │
+│ │ TickLoop (100Hz) │ │
+│ │ │ │
+│ │ READ ──► COMPUTE ──► ARBITRATE ──► ROUTE ──► WRITE │ │
+│ └──────────────────────────────────────────────────────┘ │
+│ │ │ │ │ │
+│ ▼ ▼ ▼ ▼ │
+│ ┌─────────┐ ┌───────┐ ┌─────────┐ ┌──────────┐ │
+│ │Hardware │ │ Tasks │ │Priority │ │ Backends │ │
+│ │Interface│ │ │ │ Winners │ │ │ │
+│ └─────────┘ └───────┘ └─────────┘ └──────────┘ │
+└─────────────────────────────────────────────────────────────┘
+```
+
+## Quick Start
+
+```bash
+# Terminal 1: Run orchestrator
+dimos run orchestrator-mock # Single 7-DOF mock arm
+dimos run orchestrator-dual-mock # Dual arms (7+6 DOF)
+dimos run orchestrator-piper-xarm # Real hardware
+
+# Terminal 2: Control via CLI
+python -m dimos.control.orchestrator_client
+```
+
+## Core Concepts
+
+### Tick Loop
+Single deterministic loop at 100Hz:
+1. **Read** - Get joint positions from all hardware
+2. **Compute** - Each task calculates desired output
+3. **Arbitrate** - Per-joint, highest priority wins
+4. **Route** - Group commands by hardware
+5. **Write** - Send commands to backends
+
+### Tasks (Controllers)
+Tasks are passive controllers called by the orchestrator:
+
+```python
+class MyController:
+ def claim(self) -> ResourceClaim:
+ return ResourceClaim(joints={"joint1", "joint2"}, priority=10)
+
+ def compute(self, state: OrchestratorState) -> JointCommandOutput:
+ # Your control law here (PID, impedance, etc.)
+ return JointCommandOutput(
+ joint_names=["joint1", "joint2"],
+ positions=[0.5, 0.3],
+ mode=ControlMode.POSITION,
+ )
+```
+
+### Priority & Arbitration
+Higher priority always wins. Arbitration happens every tick:
+
+```
+traj_arm (priority=10) wants joint1 = 0.5
+safety (priority=100) wants joint1 = 0.0
+ ↓
+ safety wins, traj_arm preempted
+```
+
+### Preemption
+When a task loses a joint to higher priority, it gets notified:
+
+```python
+def on_preempted(self, by_task: str, joints: frozenset[str]) -> None:
+ self._state = TrajectoryState.PREEMPTED
+```
+
+## Files
+
+```
+dimos/control/
+├── orchestrator.py # Module + RPC interface
+├── tick_loop.py # 100Hz control loop
+├── task.py # ControlTask protocol + types
+├── hardware_interface.py # Backend wrapper
+├── blueprints.py # Pre-configured setups
+└── tasks/
+ └── trajectory_task.py # Joint trajectory controller
+```
+
+## Configuration
+
+```python
+from dimos.control import control_orchestrator, HardwareConfig, TaskConfig
+
+my_robot = control_orchestrator(
+ tick_rate=100.0,
+ hardware=[
+ HardwareConfig(id="left", type="xarm", dof=7, joint_prefix="left", ip="192.168.1.100"),
+ HardwareConfig(id="right", type="piper", dof=6, joint_prefix="right", can_port="can0"),
+ ],
+ tasks=[
+ TaskConfig(name="traj_left", type="trajectory", joint_names=[...], priority=10),
+ TaskConfig(name="traj_right", type="trajectory", joint_names=[...], priority=10),
+ TaskConfig(name="safety", type="trajectory", joint_names=[...], priority=100),
+ ],
+)
+```
+
+## RPC Methods
+
+| Method | Description |
+|--------|-------------|
+| `list_hardware()` | List hardware IDs |
+| `list_joints()` | List all joint names |
+| `list_tasks()` | List task names |
+| `get_joint_positions()` | Get current positions |
+| `execute_trajectory(task, traj)` | Execute trajectory |
+| `get_trajectory_status(task)` | Get task status |
+| `cancel_trajectory(task)` | Cancel active trajectory |
+
+## Control Modes
+
+Tasks output commands in one of three modes:
+
+| Mode | Output | Use Case |
+|------|--------|----------|
+| POSITION | `q` | Trajectory following |
+| VELOCITY | `q_dot` | Joystick teleoperation |
+| TORQUE | `tau` | Force control, impedance |
+
+## Writing a Custom Task
+
+```python
+from dimos.control.task import ControlTask, ResourceClaim, JointCommandOutput, ControlMode
+
+class PIDController:
+ def __init__(self, joints: list[str], priority: int = 10):
+ self._name = "pid_controller"
+ self._claim = ResourceClaim(joints=frozenset(joints), priority=priority)
+ self._joints = joints
+ self.Kp, self.Ki, self.Kd = 10.0, 0.1, 1.0
+ self._integral = [0.0] * len(joints)
+ self._last_error = [0.0] * len(joints)
+ self.target = [0.0] * len(joints)
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ def claim(self) -> ResourceClaim:
+ return self._claim
+
+ def is_active(self) -> bool:
+ return True
+
+ def compute(self, state) -> JointCommandOutput:
+ positions = [state.joints.joint_positions[j] for j in self._joints]
+ error = [t - p for t, p in zip(self.target, positions)]
+
+ # PID
+ self._integral = [i + e * state.dt for i, e in zip(self._integral, error)]
+ derivative = [(e - le) / state.dt for e, le in zip(error, self._last_error)]
+ output = [self.Kp*e + self.Ki*i + self.Kd*d
+ for e, i, d in zip(error, self._integral, derivative)]
+ self._last_error = error
+
+ return JointCommandOutput(
+ joint_names=self._joints,
+ positions=output,
+ mode=ControlMode.POSITION,
+ )
+
+ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None:
+ pass # Handle preemption
+```
+
+## Joint State Output
+
+The orchestrator publishes one aggregated `JointState` message containing all joints:
+
+```python
+JointState(
+ name=["left_joint1", ..., "right_joint1", ...], # All joints
+ position=[...],
+ velocity=[...],
+ effort=[...],
+)
+```
+
+Subscribe via: `/orchestrator/joint_state`
diff --git a/dimos/control/__init__.py b/dimos/control/__init__.py
new file mode 100644
index 0000000000..3d7d647cd4
--- /dev/null
+++ b/dimos/control/__init__.py
@@ -0,0 +1,92 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ControlOrchestrator - Centralized control for multi-arm coordination.
+
+This module provides a centralized control orchestrator that replaces
+per-driver/per-controller loops with a single deterministic tick-based system.
+
+Features:
+- Single tick loop (read → compute → arbitrate → route → write)
+- Per-joint arbitration (highest priority wins)
+- Mode conflict detection
+- Partial command support (hold last value)
+- Aggregated preemption notifications
+
+Example:
+ >>> from dimos.control import ControlOrchestrator
+ >>> from dimos.control.tasks import JointTrajectoryTask, JointTrajectoryTaskConfig
+ >>> from dimos.hardware.manipulators.xarm import XArmBackend
+ >>>
+ >>> # Create orchestrator
+ >>> orch = ControlOrchestrator(tick_rate=100.0)
+ >>>
+ >>> # Add hardware
+ >>> backend = XArmBackend(ip="192.168.1.185", dof=7)
+ >>> backend.connect()
+ >>> orch.add_hardware("left_arm", backend, joint_prefix="left")
+ >>>
+ >>> # Add task
+ >>> joints = [f"left_joint{i+1}" for i in range(7)]
+ >>> task = JointTrajectoryTask(
+ ... "traj_left",
+ ... JointTrajectoryTaskConfig(joint_names=joints, priority=10),
+ ... )
+ >>> orch.add_task(task)
+ >>>
+ >>> # Start
+ >>> orch.start()
+"""
+
+from dimos.control.hardware_interface import (
+ BackendHardwareInterface,
+ HardwareInterface,
+)
+from dimos.control.orchestrator import (
+ ControlOrchestrator,
+ ControlOrchestratorConfig,
+ HardwareConfig,
+ TaskConfig,
+ control_orchestrator,
+)
+from dimos.control.task import (
+ ControlMode,
+ ControlTask,
+ JointCommandOutput,
+ JointStateSnapshot,
+ OrchestratorState,
+ ResourceClaim,
+)
+from dimos.control.tick_loop import TickLoop
+
+__all__ = [
+ # Hardware interface
+ "BackendHardwareInterface",
+ "ControlMode",
+ # Orchestrator
+ "ControlOrchestrator",
+ "ControlOrchestratorConfig",
+ # Task protocol and types
+ "ControlTask",
+ "HardwareConfig",
+ "HardwareInterface",
+ "JointCommandOutput",
+ "JointStateSnapshot",
+ "OrchestratorState",
+ "ResourceClaim",
+ "TaskConfig",
+ # Tick loop
+ "TickLoop",
+ "control_orchestrator",
+]
diff --git a/dimos/control/blueprints.py b/dimos/control/blueprints.py
new file mode 100644
index 0000000000..d38ac1f81f
--- /dev/null
+++ b/dimos/control/blueprints.py
@@ -0,0 +1,366 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pre-configured blueprints for the ControlOrchestrator.
+
+This module provides ready-to-use orchestrator blueprints for common setups.
+
+Usage:
+ # Run via CLI:
+ dimos run orchestrator-mock # Mock 7-DOF arm
+ dimos run orchestrator-xarm7 # XArm7 real hardware
+ dimos run orchestrator-dual-mock # Dual mock arms
+
+ # Or programmatically:
+ from dimos.control.blueprints import orchestrator_mock
+ coordinator = orchestrator_mock.build()
+ coordinator.loop()
+
+Example with trajectory setter:
+ # Terminal 1: Run the orchestrator
+ dimos run orchestrator-mock
+
+ # Terminal 2: Send trajectories via RPC
+ python -m dimos.control.examples.orchestrator_trajectory_setter --task traj_arm
+"""
+
+from __future__ import annotations
+
+from dimos.control.orchestrator import (
+ HardwareConfig,
+ TaskConfig,
+ control_orchestrator,
+)
+from dimos.core.transport import LCMTransport
+from dimos.msgs.sensor_msgs import JointState
+
+# =============================================================================
+# Helper function to generate joint names
+# =============================================================================
+
+
+def _joint_names(prefix: str, dof: int) -> list[str]:
+ """Generate joint names with prefix."""
+ return [f"{prefix}_joint{i + 1}" for i in range(dof)]
+
+
+# =============================================================================
+# Single Arm Blueprints
+# =============================================================================
+
+# Mock 7-DOF arm (for testing)
+orchestrator_mock = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="arm",
+ type="mock",
+ dof=7,
+ joint_prefix="arm",
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_arm",
+ type="trajectory",
+ joint_names=_joint_names("arm", 7),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# XArm7 real hardware (requires IP configuration)
+orchestrator_xarm7 = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="arm",
+ type="xarm",
+ dof=7,
+ joint_prefix="arm",
+ ip="192.168.2.235", # Default IP, override via env or config
+ auto_enable=True,
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_arm",
+ type="trajectory",
+ joint_names=_joint_names("arm", 7),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# XArm6 real hardware
+orchestrator_xarm6 = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="arm",
+ type="xarm",
+ dof=6,
+ joint_prefix="arm",
+ ip="192.168.1.210",
+ auto_enable=True,
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_xarm",
+ type="trajectory",
+ joint_names=_joint_names("arm", 6),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# Piper arm (6-DOF, CAN bus)
+orchestrator_piper = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="arm",
+ type="piper",
+ dof=6,
+ joint_prefix="arm",
+ can_port="can0",
+ auto_enable=True,
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_piper",
+ type="trajectory",
+ joint_names=_joint_names("arm", 6),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# =============================================================================
+# Dual Arm Blueprints
+# =============================================================================
+
+# Dual mock arms (7-DOF left, 6-DOF right) for testing
+orchestrator_dual_mock = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="left_arm",
+ type="mock",
+ dof=7,
+ joint_prefix="left",
+ ),
+ HardwareConfig(
+ id="right_arm",
+ type="mock",
+ dof=6,
+ joint_prefix="right",
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_left",
+ type="trajectory",
+ joint_names=_joint_names("left", 7),
+ priority=10,
+ ),
+ TaskConfig(
+ name="traj_right",
+ type="trajectory",
+ joint_names=_joint_names("right", 6),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# Dual XArm setup (XArm7 left, XArm6 right)
+orchestrator_dual_xarm = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="left_arm",
+ type="xarm",
+ dof=7,
+ joint_prefix="left",
+ ip="192.168.2.235",
+ auto_enable=True,
+ ),
+ HardwareConfig(
+ id="right_arm",
+ type="xarm",
+ dof=6,
+ joint_prefix="right",
+ ip="192.168.1.210",
+ auto_enable=True,
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_left",
+ type="trajectory",
+ joint_names=_joint_names("left", 7),
+ priority=10,
+ ),
+ TaskConfig(
+ name="traj_right",
+ type="trajectory",
+ joint_names=_joint_names("right", 6),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# Dual arm setup (XArm6 + Piper)
+orchestrator_piper_xarm = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="xarm_arm",
+ type="xarm",
+ dof=6,
+ joint_prefix="xarm",
+ ip="192.168.1.210",
+ auto_enable=True,
+ ),
+ HardwareConfig(
+ id="piper_arm",
+ type="piper",
+ dof=6,
+ joint_prefix="piper",
+ can_port="can0",
+ auto_enable=True,
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_xarm",
+ type="trajectory",
+ joint_names=_joint_names("xarm", 6),
+ priority=10,
+ ),
+ TaskConfig(
+ name="traj_piper",
+ type="trajectory",
+ joint_names=_joint_names("piper", 6),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# =============================================================================
+# High-frequency Blueprints (200Hz)
+# =============================================================================
+
+# High-frequency mock for demanding applications
+orchestrator_highfreq_mock = control_orchestrator(
+ tick_rate=200.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+ hardware=[
+ HardwareConfig(
+ id="arm",
+ type="mock",
+ dof=7,
+ joint_prefix="arm",
+ ),
+ ],
+ tasks=[
+ TaskConfig(
+ name="traj_arm",
+ type="trajectory",
+ joint_names=_joint_names("arm", 7),
+ priority=10,
+ ),
+ ],
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+# =============================================================================
+# Raw Blueprints (no hardware/tasks configured - for programmatic setup)
+# =============================================================================
+
+# Basic orchestrator with transport only (add hardware/tasks programmatically)
+orchestrator_basic = control_orchestrator(
+ tick_rate=100.0,
+ publish_joint_state=True,
+ joint_state_frame_id="orchestrator",
+).transports(
+ {
+ ("joint_state", JointState): LCMTransport("/orchestrator/joint_state", JointState),
+ }
+)
+
+
+__all__ = [
+ # Raw blueprints (for programmatic setup)
+ "orchestrator_basic",
+ # Dual arm blueprints
+ "orchestrator_dual_mock",
+ "orchestrator_dual_xarm",
+ # High-frequency blueprints
+ "orchestrator_highfreq_mock",
+ # Single arm blueprints
+ "orchestrator_mock",
+ "orchestrator_piper",
+ "orchestrator_piper_xarm",
+ "orchestrator_xarm6",
+ "orchestrator_xarm7",
+]
diff --git a/dimos/control/hardware_interface.py b/dimos/control/hardware_interface.py
new file mode 100644
index 0000000000..ef62f974c6
--- /dev/null
+++ b/dimos/control/hardware_interface.py
@@ -0,0 +1,230 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Hardware interface for the ControlOrchestrator.
+
+Wraps ManipulatorBackend with orchestrator-specific features:
+- Namespaced joint names (e.g., "left_joint1")
+- Unified read/write interface
+- Hold-last-value for partial commands
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from typing import Protocol, runtime_checkable
+
+from dimos.hardware.manipulators.spec import ControlMode, ManipulatorBackend
+
+logger = logging.getLogger(__name__)
+
+
+@runtime_checkable
+class HardwareInterface(Protocol):
+ """Protocol for hardware that the orchestrator can control.
+
+ This wraps ManipulatorBackend with orchestrator-specific features:
+ - Namespaced joint names (e.g., "left_arm_joint1")
+ - Unified read/write interface
+ - State caching
+ """
+
+ @property
+ def hardware_id(self) -> str:
+ """Unique ID for this hardware (e.g., 'left_arm')."""
+ ...
+
+ @property
+ def joint_names(self) -> list[str]:
+ """Ordered list of fully-qualified joint names this hardware controls."""
+ ...
+
+ def read_state(self) -> dict[str, tuple[float, float, float]]:
+ """Read current state.
+
+ Returns:
+ Dict of joint_name -> (position, velocity, effort)
+ """
+ ...
+
+ def write_command(self, commands: dict[str, float], mode: ControlMode) -> bool:
+ """Write commands to hardware.
+
+ IMPORTANT: Accepts partial joint sets. Missing joints hold last value.
+
+ Args:
+ commands: {joint_name: value} - can be partial
+ mode: Control mode (POSITION, VELOCITY, TORQUE)
+
+ Returns:
+ True if command was sent successfully
+ """
+ ...
+
+ def disconnect(self) -> None:
+ """Disconnect the underlying hardware."""
+ ...
+
+
+class BackendHardwareInterface:
+ """Concrete implementation wrapping a ManipulatorBackend.
+
+ Features:
+ - Generates namespaced joint names (prefix_joint1, prefix_joint2, ...)
+ - Holds last commanded value for partial commands
+ - On first tick, reads current position from hardware for missing joints
+ """
+
+ def __init__(
+ self,
+ backend: ManipulatorBackend,
+ hardware_id: str,
+ joint_prefix: str | None = None,
+ ) -> None:
+ """Initialize hardware interface.
+
+ Args:
+ backend: ManipulatorBackend instance (XArmBackend, PiperBackend, etc.)
+ hardware_id: Unique identifier for this hardware
+ joint_prefix: Prefix for joint names (defaults to hardware_id)
+ """
+ if not isinstance(backend, ManipulatorBackend):
+ raise TypeError("backend must implement ManipulatorBackend")
+
+ self._backend = backend
+ self._hardware_id = hardware_id
+ self._prefix = joint_prefix or hardware_id
+ self._dof = backend.get_dof()
+
+ # Generate joint names: prefix_joint1, prefix_joint2, ...
+ self._joint_names = [f"{self._prefix}_joint{i + 1}" for i in range(self._dof)]
+
+ # Track last commanded values for hold-last behavior
+ self._last_commanded: dict[str, float] = {}
+ self._initialized = False
+ self._warned_unknown_joints: set[str] = set()
+ self._current_mode: ControlMode | None = None
+
+ @property
+ def hardware_id(self) -> str:
+ """Unique ID for this hardware."""
+ return self._hardware_id
+
+ @property
+ def joint_names(self) -> list[str]:
+ """Ordered list of joint names."""
+ return self._joint_names
+
+ @property
+ def dof(self) -> int:
+ """Degrees of freedom."""
+ return self._dof
+
+ def disconnect(self) -> None:
+ """Disconnect the underlying backend."""
+ self._backend.disconnect()
+
+ def read_state(self) -> dict[str, tuple[float, float, float]]:
+ """Read state as {joint_name: (position, velocity, effort)}.
+
+ Returns:
+ Dict mapping joint name to (position, velocity, effort) tuple
+ """
+ positions = self._backend.read_joint_positions()
+ velocities = self._backend.read_joint_velocities()
+ efforts = self._backend.read_joint_efforts()
+
+ return {
+ name: (positions[i], velocities[i], efforts[i])
+ for i, name in enumerate(self._joint_names)
+ }
+
+ def write_command(self, commands: dict[str, float], mode: ControlMode) -> bool:
+ """Write commands - allows partial joint sets, holds last for missing.
+
+ This is critical for:
+ - Partial WBC overrides
+ - Safety controllers
+ - Mixed task ownership
+
+ Args:
+ commands: {joint_name: value} - can be partial
+ mode: Control mode
+
+ Returns:
+ True if command was sent successfully
+ """
+ # Initialize on first write if needed
+ if not self._initialized:
+ self._initialize_last_commanded()
+
+ # Update last commanded for joints we received
+ for joint_name, value in commands.items():
+ if joint_name in self._joint_names:
+ self._last_commanded[joint_name] = value
+ elif joint_name not in self._warned_unknown_joints:
+ logger.warning(
+ f"Hardware {self._hardware_id} received command for unknown joint "
+ f"{joint_name}. Valid joints: {self._joint_names}"
+ )
+ self._warned_unknown_joints.add(joint_name)
+
+ # Build ordered list for backend
+ ordered = self._build_ordered_command()
+
+ # Switch control mode if needed
+ if mode != self._current_mode:
+ if not self._backend.set_control_mode(mode):
+ logger.warning(f"Hardware {self._hardware_id} failed to switch to {mode.name}")
+ return False
+ self._current_mode = mode
+
+ # Send to backend
+ match mode:
+ case ControlMode.POSITION | ControlMode.SERVO_POSITION:
+ return self._backend.write_joint_positions(ordered)
+ case ControlMode.VELOCITY:
+ return self._backend.write_joint_velocities(ordered)
+ case ControlMode.TORQUE:
+ logger.warning(f"Hardware {self._hardware_id} does not support torque mode")
+ return False
+ case _:
+ return False
+
+ def _initialize_last_commanded(self) -> None:
+ """Initialize last_commanded with current hardware positions."""
+ for _ in range(10):
+ try:
+ current = self._backend.read_joint_positions()
+ for i, name in enumerate(self._joint_names):
+ self._last_commanded[name] = current[i]
+ self._initialized = True
+ return
+ except Exception:
+ time.sleep(0.01)
+
+ raise RuntimeError(
+ f"Hardware {self._hardware_id} failed to read initial positions after retries"
+ )
+
+ def _build_ordered_command(self) -> list[float]:
+ """Build ordered command list from last_commanded dict."""
+ return [self._last_commanded[name] for name in self._joint_names]
+
+
+__all__ = [
+ "BackendHardwareInterface",
+ "HardwareInterface",
+]
diff --git a/dimos/control/orchestrator.py b/dimos/control/orchestrator.py
new file mode 100644
index 0000000000..2d64620b13
--- /dev/null
+++ b/dimos/control/orchestrator.py
@@ -0,0 +1,538 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ControlOrchestrator module.
+
+Centralized control orchestrator that replaces per-driver/per-controller
+loops with a single deterministic tick-based system.
+
+Features:
+- Single tick loop (read → compute → arbitrate → route → write)
+- Per-joint arbitration (highest priority wins)
+- Mode conflict detection
+- Partial command support (hold last value)
+- Aggregated preemption notifications
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+import threading
+import time
+from typing import TYPE_CHECKING, Any
+
+from dimos.control.hardware_interface import BackendHardwareInterface, HardwareInterface
+from dimos.control.task import ControlTask
+from dimos.control.tick_loop import TickLoop
+from dimos.core import Module, Out, rpc
+from dimos.core.module import ModuleConfig
+from dimos.msgs.sensor_msgs import (
+ JointState, # noqa: TC001 - needed at runtime for Out[JointState]
+)
+from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryState
+from dimos.utils.logging_config import setup_logger
+
+if TYPE_CHECKING:
+ from dimos.hardware.manipulators.spec import ManipulatorBackend
+
+logger = setup_logger()
+
+
+# =============================================================================
+# Configuration
+# =============================================================================
+
+
+@dataclass
+class HardwareConfig:
+ """Configuration for a hardware backend.
+
+ Attributes:
+ id: Unique hardware identifier (e.g., "arm", "left_arm")
+ type: Backend type ("mock", "xarm", "piper")
+ dof: Degrees of freedom
+ joint_prefix: Prefix for joint names (defaults to id)
+ ip: IP address (required for xarm)
+ can_port: CAN port (for piper, default "can0")
+ auto_enable: Whether to auto-enable servos (default True)
+ """
+
+ id: str
+ type: str = "mock"
+ dof: int = 7
+ joint_prefix: str | None = None
+ ip: str | None = None
+ can_port: str | None = None
+ auto_enable: bool = True
+
+
+@dataclass
+class TaskConfig:
+ """Configuration for a control task.
+
+ Attributes:
+ name: Task name (e.g., "traj_arm")
+ type: Task type ("trajectory")
+ joint_names: List of joint names this task controls
+ priority: Task priority (higher wins arbitration)
+ """
+
+ name: str
+ type: str = "trajectory"
+ joint_names: list[str] = field(default_factory=lambda: [])
+ priority: int = 10
+
+
+@dataclass
+class TaskStatus:
+ """Status of a control task.
+
+ Attributes:
+ active: Whether the task is currently active
+ state: Task state name (e.g., "IDLE", "RUNNING", "DONE")
+ progress: Task progress from 0.0 to 1.0
+ """
+
+ active: bool
+ state: str | None = None
+ progress: float | None = None
+
+
+@dataclass
+class ControlOrchestratorConfig(ModuleConfig):
+ """Configuration for the ControlOrchestrator.
+
+ Attributes:
+ tick_rate: Control loop frequency in Hz (default: 100)
+ publish_joint_state: Whether to publish aggregated JointState
+ joint_state_frame_id: Frame ID for published JointState
+ log_ticks: Whether to log tick information (verbose)
+ hardware: List of hardware configurations to create on start
+ tasks: List of task configurations to create on start
+ """
+
+ tick_rate: float = 100.0
+ publish_joint_state: bool = True
+ joint_state_frame_id: str = "orchestrator"
+ log_ticks: bool = False
+ hardware: list[HardwareConfig] = field(default_factory=lambda: [])
+ tasks: list[TaskConfig] = field(default_factory=lambda: [])
+
+
+# =============================================================================
+# ControlOrchestrator Module
+# =============================================================================
+
+
+class ControlOrchestrator(Module[ControlOrchestratorConfig]):
+ """Centralized control orchestrator with per-joint arbitration.
+
+ Single tick loop that:
+ 1. Reads state from all hardware
+ 2. Runs all active tasks
+ 3. Arbitrates conflicts per-joint (highest priority wins)
+ 4. Routes commands to hardware
+ 5. Publishes aggregated joint state
+
+ Key design decisions:
+ - Joint-centric commands (not hardware-centric)
+ - Per-joint arbitration (not per-hardware)
+ - Centralized time (tasks use state.t_now, never time.time())
+ - Partial commands OK (hardware holds last value)
+ - Aggregated preemption (one notification per task per tick)
+
+ Example:
+ >>> from dimos.control import ControlOrchestrator
+ >>> from dimos.hardware.manipulators.xarm import XArmBackend
+ >>>
+ >>> orch = ControlOrchestrator(tick_rate=100.0)
+ >>> backend = XArmBackend(ip="192.168.1.185", dof=7)
+ >>> backend.connect()
+ >>> orch.add_hardware("left_arm", backend, joint_prefix="left")
+ >>> orch.start()
+ """
+
+ # Output: Aggregated joint state for external consumers
+ joint_state: Out[JointState]
+
+ config: ControlOrchestratorConfig
+ default_config = ControlOrchestratorConfig
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+
+ # Hardware interfaces (keyed by hardware_id)
+ self._hardware: dict[str, HardwareInterface] = {}
+ self._hardware_lock = threading.Lock()
+
+ # Joint -> hardware mapping (built when hardware added)
+ self._joint_to_hardware: dict[str, str] = {}
+
+ # Registered tasks
+ self._tasks: dict[str, ControlTask] = {}
+ self._task_lock = threading.Lock()
+
+ # Tick loop (created on start)
+ self._tick_loop: TickLoop | None = None
+
+ logger.info(f"ControlOrchestrator initialized at {self.config.tick_rate}Hz")
+
+ # =========================================================================
+ # Config-based Setup
+ # =========================================================================
+
+ def _setup_from_config(self) -> None:
+ """Create hardware and tasks from config (called on start)."""
+ hardware_added: list[str] = []
+
+ try:
+ for hw_cfg in self.config.hardware:
+ self._setup_hardware(hw_cfg)
+ hardware_added.append(hw_cfg.id)
+
+ for task_cfg in self.config.tasks:
+ task = self._create_task_from_config(task_cfg)
+ self.add_task(task)
+
+ except Exception:
+ # Rollback: clean up all successfully added hardware
+ for hw_id in hardware_added:
+ try:
+ self.remove_hardware(hw_id)
+ except Exception:
+ pass
+ raise
+
+ def _setup_hardware(self, hw_cfg: HardwareConfig) -> None:
+ """Connect and add a single hardware backend."""
+ backend = self._create_backend_from_config(hw_cfg)
+
+ if not backend.connect():
+ raise RuntimeError(f"Failed to connect to {hw_cfg.type} backend")
+
+ try:
+ if hw_cfg.auto_enable and hasattr(backend, "write_enable"):
+ backend.write_enable(True)
+ self.add_hardware(
+ hw_cfg.id,
+ backend,
+ joint_prefix=hw_cfg.joint_prefix or hw_cfg.id,
+ )
+ except Exception:
+ backend.disconnect()
+ raise
+
+ def _create_backend_from_config(self, cfg: HardwareConfig) -> ManipulatorBackend:
+ """Create a manipulator backend from config."""
+ match cfg.type.lower():
+ case "mock":
+ from dimos.hardware.manipulators.mock import MockBackend
+
+ return MockBackend(dof=cfg.dof)
+ case "xarm":
+ if cfg.ip is None:
+ raise ValueError("ip is required for xarm backend")
+ from dimos.hardware.manipulators.xarm import XArmBackend
+
+ return XArmBackend(ip=cfg.ip, dof=cfg.dof)
+ case "piper":
+ from dimos.hardware.manipulators.piper import PiperBackend
+
+ return PiperBackend(can_port=cfg.can_port or "can0", dof=cfg.dof)
+ case _:
+ raise ValueError(f"Unknown backend type: {cfg.type}")
+
+ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask:
+ """Create a control task from config."""
+ task_type = cfg.type.lower()
+
+ if task_type == "trajectory":
+ from dimos.control.tasks import JointTrajectoryTask, JointTrajectoryTaskConfig
+
+ return JointTrajectoryTask(
+ cfg.name,
+ JointTrajectoryTaskConfig(
+ joint_names=cfg.joint_names,
+ priority=cfg.priority,
+ ),
+ )
+
+ else:
+ raise ValueError(f"Unknown task type: {task_type}")
+
+ # =========================================================================
+ # Hardware Management (RPC)
+ # =========================================================================
+
+ @rpc
+ def add_hardware(
+ self,
+ hardware_id: str,
+ backend: ManipulatorBackend,
+ joint_prefix: str | None = None,
+ ) -> bool:
+ """Register a hardware backend with the orchestrator."""
+ with self._hardware_lock:
+ if hardware_id in self._hardware:
+ logger.warning(f"Hardware {hardware_id} already registered")
+ return False
+
+ interface = BackendHardwareInterface(
+ backend=backend,
+ hardware_id=hardware_id,
+ joint_prefix=joint_prefix,
+ )
+ self._hardware[hardware_id] = interface
+
+ for joint_name in interface.joint_names:
+ self._joint_to_hardware[joint_name] = hardware_id
+
+ logger.info(f"Added hardware {hardware_id} with joints: {interface.joint_names}")
+ return True
+
+ @rpc
+ def remove_hardware(self, hardware_id: str) -> bool:
+ """Remove a hardware interface.
+
+ Note: For safety, call this only when no tasks are actively using this
+ hardware. Consider stopping the orchestrator before removing hardware.
+ """
+ with self._hardware_lock:
+ if hardware_id not in self._hardware:
+ return False
+
+ interface = self._hardware[hardware_id]
+ hw_joints = set(interface.joint_names)
+
+ with self._task_lock:
+ for task in self._tasks.values():
+ if task.is_active():
+ claimed_joints = task.claim().joints
+ overlap = hw_joints & claimed_joints
+ if overlap:
+ logger.error(
+ f"Cannot remove hardware {hardware_id}: "
+ f"task '{task.name}' is actively using joints {overlap}"
+ )
+ return False
+
+ for joint_name in interface.joint_names:
+ del self._joint_to_hardware[joint_name]
+
+ interface.disconnect()
+ del self._hardware[hardware_id]
+ logger.info(f"Removed hardware {hardware_id}")
+ return True
+
+ @rpc
+ def list_hardware(self) -> list[str]:
+ """List registered hardware IDs."""
+ with self._hardware_lock:
+ return list(self._hardware.keys())
+
+ @rpc
+ def list_joints(self) -> list[str]:
+ """List all joint names across all hardware."""
+ with self._hardware_lock:
+ return list(self._joint_to_hardware.keys())
+
+ @rpc
+ def get_joint_positions(self) -> dict[str, float]:
+ """Get current joint positions for all joints."""
+ with self._hardware_lock:
+ positions: dict[str, float] = {}
+ for hw in self._hardware.values():
+ state = hw.read_state() # {joint_name: (pos, vel, effort)}
+ for joint_name, (pos, _vel, _effort) in state.items():
+ positions[joint_name] = pos
+ return positions
+
+ # =========================================================================
+ # Task Management (RPC)
+ # =========================================================================
+
+ @rpc
+ def add_task(self, task: ControlTask) -> bool:
+ """Register a task with the orchestrator."""
+ if not isinstance(task, ControlTask):
+ raise TypeError("task must implement ControlTask")
+
+ with self._task_lock:
+ if task.name in self._tasks:
+ logger.warning(f"Task {task.name} already registered")
+ return False
+ self._tasks[task.name] = task
+ logger.info(f"Added task {task.name}")
+ return True
+
+ @rpc
+ def remove_task(self, task_name: str) -> bool:
+ """Remove a task by name."""
+ with self._task_lock:
+ if task_name in self._tasks:
+ del self._tasks[task_name]
+ logger.info(f"Removed task {task_name}")
+ return True
+ return False
+
+ @rpc
+ def get_task(self, task_name: str) -> ControlTask | None:
+ """Get a task by name."""
+ with self._task_lock:
+ return self._tasks.get(task_name)
+
+ @rpc
+ def list_tasks(self) -> list[str]:
+ """List registered task names."""
+ with self._task_lock:
+ return list(self._tasks.keys())
+
+ @rpc
+ def get_active_tasks(self) -> list[str]:
+ """List currently active task names."""
+ with self._task_lock:
+ return [name for name, task in self._tasks.items() if task.is_active()]
+
+ # =========================================================================
+ # Trajectory Execution (RPC)
+ # =========================================================================
+
+ @rpc
+ def execute_trajectory(self, task_name: str, trajectory: JointTrajectory) -> bool:
+ """Execute a trajectory on a named task."""
+ with self._task_lock:
+ task = self._tasks.get(task_name)
+ if task is None:
+ logger.warning(f"Task {task_name} not found")
+ return False
+
+ if not hasattr(task, "execute"):
+ logger.warning(f"Task {task_name} doesn't support execute()")
+ return False
+
+ logger.info(
+ f"Executing trajectory on {task_name}: "
+ f"{len(trajectory.points)} points, duration={trajectory.duration:.3f}s"
+ )
+ return task.execute(trajectory) # type: ignore[attr-defined,no-any-return]
+
+ @rpc
+ def get_trajectory_status(self, task_name: str) -> TaskStatus | None:
+ """Get the status of a trajectory task."""
+ with self._task_lock:
+ task = self._tasks.get(task_name)
+ if task is None:
+ return None
+
+ state: str | None = None
+ if hasattr(task, "get_state"):
+ task_state: TrajectoryState = task.get_state() # type: ignore[attr-defined]
+ state = (
+ task_state.name if isinstance(task_state, TrajectoryState) else str(task_state)
+ )
+
+ progress: float | None = None
+ if hasattr(task, "get_progress"):
+ t_now = time.perf_counter()
+ progress = task.get_progress(t_now) # type: ignore[attr-defined]
+
+ return TaskStatus(active=task.is_active(), state=state, progress=progress)
+
+ @rpc
+ def cancel_trajectory(self, task_name: str) -> bool:
+ """Cancel an active trajectory on a task."""
+ with self._task_lock:
+ task = self._tasks.get(task_name)
+ if task is None:
+ logger.warning(f"Task {task_name} not found")
+ return False
+
+ if not hasattr(task, "cancel"):
+ logger.warning(f"Task {task_name} doesn't support cancel()")
+ return False
+
+ logger.info(f"Cancelling trajectory on {task_name}")
+ return task.cancel() # type: ignore[attr-defined,no-any-return]
+
+ # =========================================================================
+ # Lifecycle
+ # =========================================================================
+
+ @rpc
+ def start(self) -> None:
+ """Start the orchestrator control loop."""
+ if self._tick_loop and self._tick_loop.is_running:
+ logger.warning("Orchestrator already running")
+ return
+
+ super().start()
+
+ # Setup hardware and tasks from config (if any)
+ if self.config.hardware or self.config.tasks:
+ self._setup_from_config()
+
+ # Create and start tick loop
+ publish_cb = self.joint_state.publish if self.config.publish_joint_state else None
+ self._tick_loop = TickLoop(
+ tick_rate=self.config.tick_rate,
+ hardware=self._hardware,
+ hardware_lock=self._hardware_lock,
+ tasks=self._tasks,
+ task_lock=self._task_lock,
+ joint_to_hardware=self._joint_to_hardware,
+ publish_callback=publish_cb,
+ frame_id=self.config.joint_state_frame_id,
+ log_ticks=self.config.log_ticks,
+ )
+ self._tick_loop.start()
+
+ logger.info(f"ControlOrchestrator started at {self.config.tick_rate}Hz")
+
+ @rpc
+ def stop(self) -> None:
+ """Stop the orchestrator."""
+ logger.info("Stopping ControlOrchestrator...")
+
+ if self._tick_loop:
+ self._tick_loop.stop()
+
+ # Disconnect all hardware backends
+ with self._hardware_lock:
+ for hw_id, interface in self._hardware.items():
+ try:
+ interface.disconnect()
+ logger.info(f"Disconnected hardware {hw_id}")
+ except Exception as e:
+ logger.error(f"Error disconnecting hardware {hw_id}: {e}")
+
+ super().stop()
+ logger.info("ControlOrchestrator stopped")
+
+ @rpc
+ def get_tick_count(self) -> int:
+ """Get the number of ticks since start."""
+ return self._tick_loop.tick_count if self._tick_loop else 0
+
+
+# Blueprint export
+control_orchestrator = ControlOrchestrator.blueprint
+
+
+__all__ = [
+ "ControlOrchestrator",
+ "ControlOrchestratorConfig",
+ "HardwareConfig",
+ "TaskConfig",
+ "control_orchestrator",
+]
diff --git a/dimos/control/task.py b/dimos/control/task.py
new file mode 100644
index 0000000000..49589188d9
--- /dev/null
+++ b/dimos/control/task.py
@@ -0,0 +1,299 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ControlTask protocol and types for the ControlOrchestrator.
+
+This module defines:
+- Data types used by tasks and the orchestrator (ResourceClaim, JointStateSnapshot, etc.)
+- ControlTask protocol that all tasks must implement
+
+Tasks are "passive" - they don't own threads. The orchestrator calls
+compute() at each tick, passing current state and time.
+
+CRITICAL: Tasks must NEVER call time.time() directly.
+Use the t_now passed in OrchestratorState.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Protocol, runtime_checkable
+
+from dimos.hardware.manipulators.spec import ControlMode
+
+# =============================================================================
+# Data Types
+# =============================================================================
+
+
+@dataclass(frozen=True)
+class ResourceClaim:
+ """Declares which joints a task wants to control.
+
+ Used by the orchestrator to determine resource ownership and
+ resolve conflicts between competing tasks.
+
+ Attributes:
+ joints: Set of joint names this task wants to control.
+ Example: frozenset({"left_joint1", "left_joint2"})
+ priority: Priority level for conflict resolution. Higher wins.
+ Typical values: 10 (trajectory), 50 (WBC), 100 (safety)
+ mode: Control mode (POSITION, VELOCITY, TORQUE)
+ """
+
+ joints: frozenset[str]
+ priority: int = 0
+ mode: ControlMode = ControlMode.POSITION
+
+ def conflicts_with(self, other: ResourceClaim) -> bool:
+ """Check if two claims compete for the same joints."""
+ return bool(self.joints & other.joints)
+
+
+@dataclass
+class JointStateSnapshot:
+ """Aggregated joint states from all hardware.
+
+ Provides a unified view of all joint states across all hardware
+ interfaces, indexed by fully-qualified joint name.
+
+ Attributes:
+ joint_positions: Joint name -> position (radians)
+ joint_velocities: Joint name -> velocity (rad/s)
+ joint_efforts: Joint name -> effort (Nm)
+ timestamp: Unix timestamp when state was read
+ """
+
+ joint_positions: dict[str, float] = field(default_factory=dict)
+ joint_velocities: dict[str, float] = field(default_factory=dict)
+ joint_efforts: dict[str, float] = field(default_factory=dict)
+ timestamp: float = 0.0
+
+ def get_position(self, joint_name: str) -> float | None:
+ """Get position for a specific joint."""
+ return self.joint_positions.get(joint_name)
+
+ def get_velocity(self, joint_name: str) -> float | None:
+ """Get velocity for a specific joint."""
+ return self.joint_velocities.get(joint_name)
+
+ def get_effort(self, joint_name: str) -> float | None:
+ """Get effort for a specific joint."""
+ return self.joint_efforts.get(joint_name)
+
+
+@dataclass
+class OrchestratorState:
+ """Complete state snapshot for tasks to read.
+
+ Passed to each task's compute() method every tick. Contains
+ all information a task needs to compute its output.
+
+ CRITICAL: Tasks should use t_now for timing, never time.time()!
+
+ Attributes:
+ joints: Aggregated joint states from all hardware
+ t_now: Current tick time (time.perf_counter())
+ dt: Time since last tick (seconds)
+ """
+
+ joints: JointStateSnapshot
+ t_now: float # Orchestrator time (perf_counter) - USE THIS, NOT time.time()!
+ dt: float # Time since last tick
+
+
+@dataclass
+class JointCommandOutput:
+ """Joint-centric command output from a task.
+
+ Commands are addressed by joint name, NOT by hardware ID.
+ The orchestrator routes commands to the appropriate hardware.
+
+ This design enables:
+ - WBC spanning multiple hardware interfaces
+ - Partial joint ownership
+ - Per-joint arbitration
+
+ Attributes:
+ joint_names: Which joints this command is for
+ positions: Position commands (radians), or None
+ velocities: Velocity commands (rad/s), or None
+ efforts: Effort commands (Nm), or None
+ mode: Control mode - must match which field is populated
+ """
+
+ joint_names: list[str]
+ positions: list[float] | None = None
+ velocities: list[float] | None = None
+ efforts: list[float] | None = None
+ mode: ControlMode = ControlMode.POSITION
+
+ def __post_init__(self) -> None:
+ """Validate that lengths match and at least one value field is set."""
+ n = len(self.joint_names)
+
+ if self.positions is not None and len(self.positions) != n:
+ raise ValueError(f"positions length {len(self.positions)} != joint_names length {n}")
+ if self.velocities is not None and len(self.velocities) != n:
+ raise ValueError(f"velocities length {len(self.velocities)} != joint_names length {n}")
+ if self.efforts is not None and len(self.efforts) != n:
+ raise ValueError(f"efforts length {len(self.efforts)} != joint_names length {n}")
+
+ def get_values(self) -> list[float] | None:
+ """Get the active values based on mode."""
+ match self.mode:
+ case ControlMode.POSITION | ControlMode.SERVO_POSITION:
+ return self.positions
+ case ControlMode.VELOCITY:
+ return self.velocities
+ case ControlMode.TORQUE:
+ return self.efforts
+ case _:
+ return None
+
+
+# =============================================================================
+# ControlTask Protocol
+# =============================================================================
+
+
+@runtime_checkable
+class ControlTask(Protocol):
+ """Protocol for passive tasks that run within the orchestrator.
+
+ Tasks are "passive" - they don't own threads. The orchestrator
+ calls compute() at each tick, passing current state and time.
+
+ Lifecycle:
+ 1. Task is added to orchestrator via add_task()
+ 2. Orchestrator calls claim() to understand resource needs
+ 3. Each tick: is_active() → compute() → output merged via arbitration
+ 4. Task removed via remove_task() or transitions to inactive
+
+ CRITICAL: Tasks must NEVER call time.time() directly.
+ Use state.t_now passed to compute() for all timing.
+
+ Example:
+ >>> class MyTask:
+ ... @property
+ ... def name(self) -> str:
+ ... return "my_task"
+ ...
+ ... def claim(self) -> ResourceClaim:
+ ... return ResourceClaim(
+ ... joints=frozenset(["left_joint1", "left_joint2"]),
+ ... priority=10,
+ ... )
+ ...
+ ... def is_active(self) -> bool:
+ ... return self._executing
+ ...
+ ... def compute(self, state: OrchestratorState) -> JointCommandOutput | None:
+ ... # Use state.t_now, NOT time.time()!
+ ... t_elapsed = state.t_now - self._start_time
+ ... positions = self._trajectory.sample(t_elapsed)
+ ... return JointCommandOutput(
+ ... joint_names=["left_joint1", "left_joint2"],
+ ... positions=positions,
+ ... )
+ ...
+ ... def on_preempted(self, by_task: str, joints: frozenset[str]) -> None:
+ ... print(f"Preempted by {by_task} on {joints}")
+ """
+
+ @property
+ def name(self) -> str:
+ """Unique identifier for this task instance.
+
+ Used for logging, debugging, and task management.
+ Must be unique across all tasks in the orchestrator.
+ """
+ ...
+
+ def claim(self) -> ResourceClaim:
+ """Declare resource requirements.
+
+ Called by orchestrator to determine:
+ - Which joints this task wants to control
+ - Priority for conflict resolution
+ - Control mode (position/velocity/effort)
+
+ Returns:
+ ResourceClaim with joints (frozenset) and priority (int)
+
+ Note:
+ The claim can change dynamically - orchestrator calls this
+ every tick for active tasks.
+ """
+ ...
+
+ def is_active(self) -> bool:
+ """Check if task should run this tick.
+
+ Inactive tasks:
+ - Skip compute() call
+ - Don't participate in arbitration
+ - Don't consume resources
+
+ Returns:
+ True if task should execute this tick
+ """
+ ...
+
+ def compute(self, state: OrchestratorState) -> JointCommandOutput | None:
+ """Compute output command given current state.
+
+ Called by orchestrator for active tasks each tick.
+
+ CRITICAL: Use state.t_now for timing, NEVER time.time()!
+ This ensures deterministic behavior and enables simulation.
+
+ Args:
+ state: OrchestratorState containing:
+ - joints: JointStateSnapshot with all joint states
+ - t_now: Current tick time (use this for all timing!)
+ - dt: Time since last tick
+
+ Returns:
+ JointCommandOutput with joint_names and values, or None if
+ no command should be sent this tick.
+ """
+ ...
+
+ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None:
+ """Called ONCE per tick with ALL preempted joints aggregated.
+
+ Called when a higher-priority task takes control of some of this
+ task's joints. Allows task to gracefully handle being overridden.
+
+ This is called ONCE per tick with ALL preempted joints, not once
+ per joint. This reduces noise and improves performance.
+
+ Args:
+ by_task: Name of the preempting task (or "arbitration" if multiple)
+ joints: All joints that were preempted this tick
+ """
+ ...
+
+
+__all__ = [
+ # Types
+ "ControlMode",
+ # Protocol
+ "ControlTask",
+ "JointCommandOutput",
+ "JointStateSnapshot",
+ "OrchestratorState",
+ "ResourceClaim",
+]
diff --git a/dimos/control/tasks/__init__.py b/dimos/control/tasks/__init__.py
new file mode 100644
index 0000000000..75460ffa26
--- /dev/null
+++ b/dimos/control/tasks/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Task implementations for the ControlOrchestrator."""
+
+from dimos.control.tasks.trajectory_task import (
+ JointTrajectoryTask,
+ JointTrajectoryTaskConfig,
+)
+
+__all__ = [
+ "JointTrajectoryTask",
+ "JointTrajectoryTaskConfig",
+]
diff --git a/dimos/control/tasks/trajectory_task.py b/dimos/control/tasks/trajectory_task.py
new file mode 100644
index 0000000000..08e3ae337e
--- /dev/null
+++ b/dimos/control/tasks/trajectory_task.py
@@ -0,0 +1,261 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Joint trajectory task for the ControlOrchestrator.
+
+Passive trajectory execution - called by orchestrator each tick.
+Unlike JointTrajectoryController which owns a thread, this task
+is compute-only and relies on the orchestrator for timing.
+
+CRITICAL: Uses t_now from OrchestratorState, never calls time.time()
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from dimos.control.task import (
+ ControlMode,
+ ControlTask,
+ JointCommandOutput,
+ OrchestratorState,
+ ResourceClaim,
+)
+from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryState
+from dimos.utils.logging_config import setup_logger
+
+logger = setup_logger()
+
+
+@dataclass
+class JointTrajectoryTaskConfig:
+ """Configuration for trajectory task.
+
+ Attributes:
+ joint_names: List of joint names this task controls
+ priority: Priority for arbitration (higher wins)
+ """
+
+ joint_names: list[str]
+ priority: int = 10
+
+
+class JointTrajectoryTask(ControlTask):
+ """Passive trajectory execution task.
+
+ Unlike JointTrajectoryController which owns a thread, this task
+ is called by the orchestrator at each tick.
+
+ CRITICAL: Uses t_now from OrchestratorState, never calls time.time()
+
+ State Machine:
+ IDLE ──execute()──► EXECUTING ──done──► COMPLETED
+ ▲ │ │
+ │ cancel() reset()
+ │ ▼ │
+ └─────reset()───── ABORTED ◄──────────────┘
+
+ Example:
+ >>> task = JointTrajectoryTask(
+ ... name="traj_left",
+ ... config=JointTrajectoryTaskConfig(
+ ... joint_names=["left_joint1", "left_joint2"],
+ ... priority=10,
+ ... ),
+ ... )
+ >>> orchestrator.add_task(task)
+ >>> task.execute(my_trajectory, t_now=orchestrator_t_now)
+ """
+
+ def __init__(self, name: str, config: JointTrajectoryTaskConfig) -> None:
+ """Initialize trajectory task.
+
+ Args:
+ name: Unique task name
+ config: Task configuration
+ """
+ if not config.joint_names:
+ raise ValueError(f"JointTrajectoryTask '{name}' requires at least one joint")
+ self._name = name
+ self._config = config
+ self._joint_names = frozenset(config.joint_names)
+ self._joint_names_list = list(config.joint_names)
+
+ # State machine
+ self._state = TrajectoryState.IDLE
+ self._trajectory: JointTrajectory | None = None
+ self._start_time: float = 0.0
+ self._pending_start: bool = False # Defer start time to first compute()
+
+ logger.info(f"JointTrajectoryTask {name} initialized for joints: {config.joint_names}")
+
+ @property
+ def name(self) -> str:
+ """Unique task identifier."""
+ return self._name
+
+ def claim(self) -> ResourceClaim:
+ """Declare resource requirements."""
+ return ResourceClaim(
+ joints=self._joint_names,
+ priority=self._config.priority,
+ mode=ControlMode.SERVO_POSITION,
+ )
+
+ def is_active(self) -> bool:
+ """Check if task should run this tick."""
+ return self._state == TrajectoryState.EXECUTING
+
+ def compute(self, state: OrchestratorState) -> JointCommandOutput | None:
+ """Compute trajectory output for this tick.
+
+ CRITICAL: Uses state.t_now for timing, NOT time.time()!
+
+ Args:
+ state: Current orchestrator state
+
+ Returns:
+ JointCommandOutput with positions, or None if not executing
+ """
+ if self._trajectory is None:
+ return None
+
+ # Set start time on first compute() for consistent timing
+ if self._pending_start:
+ self._start_time = state.t_now
+ self._pending_start = False
+
+ t_elapsed = state.t_now - self._start_time
+
+ # Check completion - clamp to final position to ensure we reach goal
+ if t_elapsed >= self._trajectory.duration:
+ self._state = TrajectoryState.COMPLETED
+ logger.info(f"Trajectory {self._name} completed after {t_elapsed:.3f}s")
+ # Return final position to hold at goal
+ q_ref, _ = self._trajectory.sample(self._trajectory.duration)
+ return JointCommandOutput(
+ joint_names=self._joint_names_list,
+ positions=list(q_ref),
+ mode=ControlMode.SERVO_POSITION,
+ )
+
+ # Sample trajectory
+ q_ref, _ = self._trajectory.sample(t_elapsed)
+
+ return JointCommandOutput(
+ joint_names=self._joint_names_list,
+ positions=list(q_ref),
+ mode=ControlMode.SERVO_POSITION,
+ )
+
+ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None:
+ """Handle preemption by higher-priority task.
+
+ Args:
+ by_task: Name of preempting task
+ joints: Joints that were preempted
+ """
+ logger.warning(f"Trajectory {self._name} preempted by {by_task} on joints {joints}")
+ # Abort if any of our joints were preempted
+ if joints & self._joint_names:
+ self._state = TrajectoryState.ABORTED
+
+ # =========================================================================
+ # Task-specific methods
+ # =========================================================================
+
+ def execute(self, trajectory: JointTrajectory) -> bool:
+ """Start executing a trajectory.
+
+ Args:
+ trajectory: Trajectory to execute
+
+ Returns:
+ True if accepted, False if invalid or in FAULT state
+ """
+ if self._state == TrajectoryState.FAULT:
+ logger.warning(f"Cannot execute: {self._name} in FAULT state")
+ return False
+
+ if trajectory is None or trajectory.duration <= 0:
+ logger.warning(f"Invalid trajectory for {self._name}")
+ return False
+
+ if not trajectory.points:
+ logger.warning(f"Empty trajectory for {self._name}")
+ return False
+
+ # Preempt any active trajectory
+ if self._state == TrajectoryState.EXECUTING:
+ logger.info(f"Preempting active trajectory on {self._name}")
+
+ self._trajectory = trajectory
+ self._pending_start = True # Start time set on first compute()
+ self._state = TrajectoryState.EXECUTING
+
+ logger.info(
+ f"Executing trajectory on {self._name}: "
+ f"{len(trajectory.points)} points, duration={trajectory.duration:.3f}s"
+ )
+ return True
+
+ def cancel(self) -> bool:
+ """Cancel current trajectory.
+
+ Returns:
+ True if cancelled, False if not executing
+ """
+ if self._state != TrajectoryState.EXECUTING:
+ return False
+ self._state = TrajectoryState.ABORTED
+ logger.info(f"Trajectory {self._name} cancelled")
+ return True
+
+ def reset(self) -> bool:
+ """Reset to idle state.
+
+ Returns:
+ True if reset, False if currently executing
+ """
+ if self._state == TrajectoryState.EXECUTING:
+ logger.warning(f"Cannot reset {self._name} while executing")
+ return False
+ self._state = TrajectoryState.IDLE
+ self._trajectory = None
+ logger.info(f"Trajectory {self._name} reset to IDLE")
+ return True
+
+ def get_state(self) -> TrajectoryState:
+ """Get current state."""
+ return self._state
+
+ def get_progress(self, t_now: float) -> float:
+ """Get execution progress (0.0 to 1.0).
+
+ Args:
+ t_now: Current orchestrator time
+
+ Returns:
+ Progress as fraction, or 0.0 if not executing
+ """
+ if self._state != TrajectoryState.EXECUTING or self._trajectory is None:
+ return 0.0
+ t_elapsed = t_now - self._start_time
+ return min(1.0, t_elapsed / self._trajectory.duration)
+
+
+__all__ = [
+ "JointTrajectoryTask",
+ "JointTrajectoryTaskConfig",
+]
diff --git a/dimos/control/test_control.py b/dimos/control/test_control.py
new file mode 100644
index 0000000000..2522affa60
--- /dev/null
+++ b/dimos/control/test_control.py
@@ -0,0 +1,542 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the Control Orchestrator module."""
+
+from __future__ import annotations
+
+import threading
+import time
+from unittest.mock import MagicMock
+
+import pytest
+
+from dimos.control.hardware_interface import BackendHardwareInterface
+from dimos.control.task import (
+ ControlMode,
+ JointCommandOutput,
+ JointStateSnapshot,
+ OrchestratorState,
+ ResourceClaim,
+)
+from dimos.control.tasks.trajectory_task import (
+ JointTrajectoryTask,
+ JointTrajectoryTaskConfig,
+ TrajectoryState,
+)
+from dimos.control.tick_loop import TickLoop
+from dimos.hardware.manipulators.spec import ManipulatorBackend
+from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint
+
+# =============================================================================
+# Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def mock_backend():
+ """Create a mock manipulator backend."""
+ backend = MagicMock(spec=ManipulatorBackend)
+ backend.get_dof.return_value = 6
+ backend.read_joint_positions.return_value = [0.0] * 6
+ backend.read_joint_velocities.return_value = [0.0] * 6
+ backend.read_joint_efforts.return_value = [0.0] * 6
+ backend.write_joint_positions.return_value = True
+ backend.write_joint_velocities.return_value = True
+ backend.set_control_mode.return_value = True
+ return backend
+
+
+@pytest.fixture
+def hardware_interface(mock_backend):
+ """Create a BackendHardwareInterface with mock backend."""
+ return BackendHardwareInterface(
+ backend=mock_backend,
+ hardware_id="test_arm",
+ joint_prefix="arm",
+ )
+
+
+@pytest.fixture
+def trajectory_task():
+ """Create a JointTrajectoryTask for testing."""
+ config = JointTrajectoryTaskConfig(
+ joint_names=["arm_joint1", "arm_joint2", "arm_joint3"],
+ priority=10,
+ )
+ return JointTrajectoryTask(name="test_traj", config=config)
+
+
+@pytest.fixture
+def simple_trajectory():
+ """Create a simple 2-point trajectory."""
+ return JointTrajectory(
+ joint_names=["arm_joint1", "arm_joint2", "arm_joint3"],
+ points=[
+ TrajectoryPoint(
+ positions=[0.0, 0.0, 0.0],
+ velocities=[0.0, 0.0, 0.0],
+ time_from_start=0.0,
+ ),
+ TrajectoryPoint(
+ positions=[1.0, 0.5, 0.25],
+ velocities=[0.0, 0.0, 0.0],
+ time_from_start=1.0,
+ ),
+ ],
+ )
+
+
+@pytest.fixture
+def orchestrator_state():
+ """Create a sample OrchestratorState."""
+ joints = JointStateSnapshot(
+ joint_positions={"arm_joint1": 0.0, "arm_joint2": 0.0, "arm_joint3": 0.0},
+ joint_velocities={"arm_joint1": 0.0, "arm_joint2": 0.0, "arm_joint3": 0.0},
+ joint_efforts={"arm_joint1": 0.0, "arm_joint2": 0.0, "arm_joint3": 0.0},
+ timestamp=time.perf_counter(),
+ )
+ return OrchestratorState(joints=joints, t_now=time.perf_counter(), dt=0.01)
+
+
+# =============================================================================
+# Test JointCommandOutput
+# =============================================================================
+
+
+class TestJointCommandOutput:
+ def test_position_output(self):
+ output = JointCommandOutput(
+ joint_names=["j1", "j2"],
+ positions=[0.5, 1.0],
+ mode=ControlMode.POSITION,
+ )
+ assert output.get_values() == [0.5, 1.0]
+ assert output.mode == ControlMode.POSITION
+
+ def test_velocity_output(self):
+ output = JointCommandOutput(
+ joint_names=["j1", "j2"],
+ velocities=[0.1, 0.2],
+ mode=ControlMode.VELOCITY,
+ )
+ assert output.get_values() == [0.1, 0.2]
+ assert output.mode == ControlMode.VELOCITY
+
+ def test_torque_output(self):
+ output = JointCommandOutput(
+ joint_names=["j1", "j2"],
+ efforts=[5.0, 10.0],
+ mode=ControlMode.TORQUE,
+ )
+ assert output.get_values() == [5.0, 10.0]
+ assert output.mode == ControlMode.TORQUE
+
+ def test_no_values_returns_none(self):
+ output = JointCommandOutput(
+ joint_names=["j1"],
+ mode=ControlMode.POSITION,
+ )
+ assert output.get_values() is None
+
+
+# =============================================================================
+# Test JointStateSnapshot
+# =============================================================================
+
+
+class TestJointStateSnapshot:
+ def test_get_position(self):
+ snapshot = JointStateSnapshot(
+ joint_positions={"j1": 0.5, "j2": 1.0},
+ joint_velocities={"j1": 0.0, "j2": 0.1},
+ joint_efforts={"j1": 1.0, "j2": 2.0},
+ timestamp=100.0,
+ )
+ assert snapshot.get_position("j1") == 0.5
+ assert snapshot.get_position("j2") == 1.0
+ assert snapshot.get_position("nonexistent") is None
+
+
+# =============================================================================
+# Test BackendHardwareInterface
+# =============================================================================
+
+
+class TestBackendHardwareInterface:
+ def test_joint_names_prefixed(self, hardware_interface):
+ names = hardware_interface.joint_names
+ assert names == [
+ "arm_joint1",
+ "arm_joint2",
+ "arm_joint3",
+ "arm_joint4",
+ "arm_joint5",
+ "arm_joint6",
+ ]
+
+ def test_read_state(self, hardware_interface):
+ state = hardware_interface.read_state()
+ assert "arm_joint1" in state
+ assert len(state) == 6
+ pos, vel, eff = state["arm_joint1"]
+ assert pos == 0.0
+ assert vel == 0.0
+ assert eff == 0.0
+
+ def test_write_command(self, hardware_interface, mock_backend):
+ commands = {
+ "arm_joint1": 0.5,
+ "arm_joint2": 1.0,
+ }
+ hardware_interface.write_command(commands, ControlMode.POSITION)
+ mock_backend.write_joint_positions.assert_called()
+
+
+# =============================================================================
+# Test JointTrajectoryTask
+# =============================================================================
+
+
+class TestJointTrajectoryTask:
+ def test_initial_state(self, trajectory_task):
+ assert trajectory_task.name == "test_traj"
+ assert not trajectory_task.is_active()
+ assert trajectory_task.get_state() == TrajectoryState.IDLE
+
+ def test_claim(self, trajectory_task):
+ claim = trajectory_task.claim()
+ assert claim.priority == 10
+ assert "arm_joint1" in claim.joints
+ assert "arm_joint2" in claim.joints
+ assert "arm_joint3" in claim.joints
+
+ def test_execute_trajectory(self, trajectory_task, simple_trajectory):
+ time.perf_counter()
+ result = trajectory_task.execute(simple_trajectory)
+ assert result is True
+ assert trajectory_task.is_active()
+ assert trajectory_task.get_state() == TrajectoryState.EXECUTING
+
+ def test_compute_during_trajectory(
+ self, trajectory_task, simple_trajectory, orchestrator_state
+ ):
+ t_start = time.perf_counter()
+ trajectory_task.execute(simple_trajectory)
+
+ # First compute sets start time (deferred start)
+ state0 = OrchestratorState(
+ joints=orchestrator_state.joints,
+ t_now=t_start,
+ dt=0.01,
+ )
+ trajectory_task.compute(state0)
+
+ # Compute at 0.5s into trajectory
+ state = OrchestratorState(
+ joints=orchestrator_state.joints,
+ t_now=t_start + 0.5,
+ dt=0.01,
+ )
+ output = trajectory_task.compute(state)
+
+ assert output is not None
+ assert output.mode == ControlMode.SERVO_POSITION
+ assert len(output.positions) == 3
+ assert 0.4 < output.positions[0] < 0.6
+
+ def test_trajectory_completes(self, trajectory_task, simple_trajectory, orchestrator_state):
+ t_start = time.perf_counter()
+ trajectory_task.execute(simple_trajectory)
+
+ # First compute sets start time (deferred start)
+ state0 = OrchestratorState(
+ joints=orchestrator_state.joints,
+ t_now=t_start,
+ dt=0.01,
+ )
+ trajectory_task.compute(state0)
+
+ # Compute past trajectory duration
+ state = OrchestratorState(
+ joints=orchestrator_state.joints,
+ t_now=t_start + 1.5,
+ dt=0.01,
+ )
+ output = trajectory_task.compute(state)
+
+ # On completion, returns final position (not None) to hold at goal
+ assert output is not None
+ assert output.positions == [1.0, 0.5, 0.25] # Final trajectory point
+ assert not trajectory_task.is_active()
+ assert trajectory_task.get_state() == TrajectoryState.COMPLETED
+
+ def test_cancel_trajectory(self, trajectory_task, simple_trajectory):
+ trajectory_task.execute(simple_trajectory)
+ assert trajectory_task.is_active()
+
+ trajectory_task.cancel()
+ assert not trajectory_task.is_active()
+ assert trajectory_task.get_state() == TrajectoryState.ABORTED
+
+ def test_preemption(self, trajectory_task, simple_trajectory):
+ trajectory_task.execute(simple_trajectory)
+
+ trajectory_task.on_preempted("safety_task", frozenset({"arm_joint1"}))
+ assert trajectory_task.get_state() == TrajectoryState.ABORTED
+ assert not trajectory_task.is_active()
+
+ def test_progress(self, trajectory_task, simple_trajectory, orchestrator_state):
+ t_start = time.perf_counter()
+ trajectory_task.execute(simple_trajectory)
+
+ # First compute sets start time (deferred start)
+ state0 = OrchestratorState(
+ joints=orchestrator_state.joints,
+ t_now=t_start,
+ dt=0.01,
+ )
+ trajectory_task.compute(state0)
+
+ assert trajectory_task.get_progress(t_start) == pytest.approx(0.0, abs=0.01)
+ assert trajectory_task.get_progress(t_start + 0.5) == pytest.approx(0.5, abs=0.01)
+ assert trajectory_task.get_progress(t_start + 1.0) == pytest.approx(1.0, abs=0.01)
+
+
+# =============================================================================
+# Test Arbitration Logic
+# =============================================================================
+
+
+class TestArbitration:
+ def test_single_task_wins(self):
+ outputs = [
+ (
+ MagicMock(name="task1"),
+ ResourceClaim(joints=frozenset({"j1"}), priority=10),
+ JointCommandOutput(joint_names=["j1"], positions=[0.5], mode=ControlMode.POSITION),
+ ),
+ ]
+
+ winners = {}
+ for task, claim, output in outputs:
+ if output is None:
+ continue
+ values = output.get_values()
+ if values is None:
+ continue
+ for i, joint in enumerate(output.joint_names):
+ if joint not in winners:
+ winners[joint] = (claim.priority, values[i], output.mode, task.name)
+
+ assert "j1" in winners
+ assert winners["j1"][1] == 0.5
+
+ def test_higher_priority_wins(self):
+ task_low = MagicMock()
+ task_low.name = "low_priority"
+ task_high = MagicMock()
+ task_high.name = "high_priority"
+
+ outputs = [
+ (
+ task_low,
+ ResourceClaim(joints=frozenset({"j1"}), priority=10),
+ JointCommandOutput(joint_names=["j1"], positions=[0.5], mode=ControlMode.POSITION),
+ ),
+ (
+ task_high,
+ ResourceClaim(joints=frozenset({"j1"}), priority=100),
+ JointCommandOutput(joint_names=["j1"], positions=[0.0], mode=ControlMode.POSITION),
+ ),
+ ]
+
+ winners = {}
+ for task, claim, output in outputs:
+ if output is None:
+ continue
+ values = output.get_values()
+ if values is None:
+ continue
+ for i, joint in enumerate(output.joint_names):
+ if joint not in winners:
+ winners[joint] = (claim.priority, values[i], output.mode, task.name)
+ elif claim.priority > winners[joint][0]:
+ winners[joint] = (claim.priority, values[i], output.mode, task.name)
+
+ assert winners["j1"][3] == "high_priority"
+ assert winners["j1"][1] == 0.0
+
+ def test_non_overlapping_joints(self):
+ task1 = MagicMock()
+ task1.name = "task1"
+ task2 = MagicMock()
+ task2.name = "task2"
+
+ outputs = [
+ (
+ task1,
+ ResourceClaim(joints=frozenset({"j1", "j2"}), priority=10),
+ JointCommandOutput(
+ joint_names=["j1", "j2"],
+ positions=[0.5, 0.6],
+ mode=ControlMode.POSITION,
+ ),
+ ),
+ (
+ task2,
+ ResourceClaim(joints=frozenset({"j3", "j4"}), priority=10),
+ JointCommandOutput(
+ joint_names=["j3", "j4"],
+ positions=[0.7, 0.8],
+ mode=ControlMode.POSITION,
+ ),
+ ),
+ ]
+
+ winners = {}
+ for task, claim, output in outputs:
+ if output is None:
+ continue
+ values = output.get_values()
+ if values is None:
+ continue
+ for i, joint in enumerate(output.joint_names):
+ if joint not in winners:
+ winners[joint] = (claim.priority, values[i], output.mode, task.name)
+
+ assert winners["j1"][3] == "task1"
+ assert winners["j2"][3] == "task1"
+ assert winners["j3"][3] == "task2"
+ assert winners["j4"][3] == "task2"
+
+
+# =============================================================================
+# Test TickLoop
+# =============================================================================
+
+
+class TestTickLoop:
+ def test_tick_loop_starts_and_stops(self, mock_backend):
+ hw = BackendHardwareInterface(mock_backend, "arm", "arm")
+ hardware = {"arm": hw}
+ tasks: dict = {}
+ joint_to_hardware = {f"arm_joint{i + 1}": "arm" for i in range(6)}
+
+ tick_loop = TickLoop(
+ tick_rate=100.0,
+ hardware=hardware,
+ hardware_lock=threading.Lock(),
+ tasks=tasks,
+ task_lock=threading.Lock(),
+ joint_to_hardware=joint_to_hardware,
+ )
+
+ tick_loop.start()
+ time.sleep(0.05)
+ assert tick_loop.tick_count > 0
+
+ tick_loop.stop()
+ final_count = tick_loop.tick_count
+ time.sleep(0.02)
+ assert tick_loop.tick_count == final_count
+
+ def test_tick_loop_calls_compute(self, mock_backend):
+ hw = BackendHardwareInterface(mock_backend, "arm", "arm")
+ hardware = {"arm": hw}
+
+ mock_task = MagicMock()
+ mock_task.name = "test_task"
+ mock_task.is_active.return_value = True
+ mock_task.claim.return_value = ResourceClaim(
+ joints=frozenset({"arm_joint1"}),
+ priority=10,
+ )
+ mock_task.compute.return_value = JointCommandOutput(
+ joint_names=["arm_joint1"],
+ positions=[0.5],
+ mode=ControlMode.POSITION,
+ )
+
+ tasks = {"test_task": mock_task}
+ joint_to_hardware = {f"arm_joint{i + 1}": "arm" for i in range(6)}
+
+ tick_loop = TickLoop(
+ tick_rate=100.0,
+ hardware=hardware,
+ hardware_lock=threading.Lock(),
+ tasks=tasks,
+ task_lock=threading.Lock(),
+ joint_to_hardware=joint_to_hardware,
+ )
+
+ tick_loop.start()
+ time.sleep(0.05)
+ tick_loop.stop()
+
+ assert mock_task.compute.call_count > 0
+
+
+# =============================================================================
+# Integration Test
+# =============================================================================
+
+
+class TestIntegration:
+ def test_full_trajectory_execution(self, mock_backend):
+ hw = BackendHardwareInterface(mock_backend, "arm", "arm")
+ hardware = {"arm": hw}
+
+ config = JointTrajectoryTaskConfig(
+ joint_names=[f"arm_joint{i + 1}" for i in range(6)],
+ priority=10,
+ )
+ traj_task = JointTrajectoryTask(name="traj_arm", config=config)
+ tasks = {"traj_arm": traj_task}
+
+ joint_to_hardware = {f"arm_joint{i + 1}": "arm" for i in range(6)}
+
+ tick_loop = TickLoop(
+ tick_rate=100.0,
+ hardware=hardware,
+ hardware_lock=threading.Lock(),
+ tasks=tasks,
+ task_lock=threading.Lock(),
+ joint_to_hardware=joint_to_hardware,
+ )
+
+ trajectory = JointTrajectory(
+ joint_names=[f"arm_joint{i + 1}" for i in range(6)],
+ points=[
+ TrajectoryPoint(
+ positions=[0.0] * 6,
+ velocities=[0.0] * 6,
+ time_from_start=0.0,
+ ),
+ TrajectoryPoint(
+ positions=[0.5] * 6,
+ velocities=[0.0] * 6,
+ time_from_start=0.5,
+ ),
+ ],
+ )
+
+ tick_loop.start()
+ traj_task.execute(trajectory)
+
+ time.sleep(0.6)
+ tick_loop.stop()
+
+ assert traj_task.get_state() == TrajectoryState.COMPLETED
+ assert mock_backend.write_joint_positions.call_count > 0
diff --git a/dimos/control/tick_loop.py b/dimos/control/tick_loop.py
new file mode 100644
index 0000000000..03e4e0ebd0
--- /dev/null
+++ b/dimos/control/tick_loop.py
@@ -0,0 +1,399 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tick loop for the ControlOrchestrator.
+
+This module contains the core control loop logic:
+- Read state from all hardware
+- Compute outputs from all active tasks
+- Arbitrate conflicts per-joint (highest priority wins)
+- Route commands to hardware
+- Publish aggregated joint state
+
+Separated from orchestrator.py following the DimOS pattern of
+splitting the coordination logic from the module wrapper.
+"""
+
+from __future__ import annotations
+
+import threading
+import time
+from typing import TYPE_CHECKING, NamedTuple
+
+from dimos.control.task import (
+ ControlTask,
+ JointCommandOutput,
+ JointStateSnapshot,
+ OrchestratorState,
+ ResourceClaim,
+)
+from dimos.msgs.sensor_msgs import JointState
+from dimos.utils.logging_config import setup_logger
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from dimos.control.hardware_interface import HardwareInterface
+ from dimos.hardware.manipulators.spec import ControlMode
+
+logger = setup_logger()
+
+
+class JointWinner(NamedTuple):
+ """Tracks the winning task for a joint during arbitration."""
+
+ priority: int
+ value: float
+ mode: ControlMode
+ task_name: str
+
+
+class TickLoop:
+ """Core tick loop for the control orchestrator.
+
+ Runs the deterministic control cycle:
+ 1. READ: Collect joint state from all hardware
+ 2. COMPUTE: Run all active tasks
+ 3. ARBITRATE: Per-joint conflict resolution (highest priority wins)
+ 4. NOTIFY: Send preemption notifications to affected tasks
+ 5. ROUTE: Convert joint-centric commands to hardware-centric
+ 6. WRITE: Send commands to hardware
+ 7. PUBLISH: Output aggregated JointState
+
+ Args:
+ tick_rate: Control loop frequency in Hz
+ hardware: Dict of hardware_id -> HardwareInterface
+ hardware_lock: Lock protecting hardware dict
+ tasks: Dict of task_name -> ControlTask
+ task_lock: Lock protecting tasks dict
+ joint_to_hardware: Dict mapping joint_name -> hardware_id
+ publish_callback: Optional callback to publish JointState
+ frame_id: Frame ID for published JointState
+ log_ticks: Whether to log tick information
+ """
+
+ def __init__(
+ self,
+ tick_rate: float,
+ hardware: dict[str, HardwareInterface],
+ hardware_lock: threading.Lock,
+ tasks: dict[str, ControlTask],
+ task_lock: threading.Lock,
+ joint_to_hardware: dict[str, str],
+ publish_callback: Callable[[JointState], None] | None = None,
+ frame_id: str = "orchestrator",
+ log_ticks: bool = False,
+ ) -> None:
+ self._tick_rate = tick_rate
+ self._hardware = hardware
+ self._hardware_lock = hardware_lock
+ self._tasks = tasks
+ self._task_lock = task_lock
+ self._joint_to_hardware = joint_to_hardware
+ self._publish_callback = publish_callback
+ self._frame_id = frame_id
+ self._log_ticks = log_ticks
+
+ self._stop_event = threading.Event()
+ self._stop_event.set() # Initially stopped
+ self._tick_thread: threading.Thread | None = None
+ self._last_tick_time: float = 0.0
+ self._tick_count: int = 0
+
+ @property
+ def tick_count(self) -> int:
+ """Number of ticks since start."""
+ return self._tick_count
+
+ @property
+ def is_running(self) -> bool:
+ """Whether the tick loop is currently running."""
+ return not self._stop_event.is_set()
+
+ def start(self) -> None:
+ """Start the tick loop in a daemon thread."""
+ if not self._stop_event.is_set():
+ logger.warning("TickLoop already running")
+ return
+
+ self._stop_event.clear()
+ self._last_tick_time = time.perf_counter()
+ self._tick_count = 0
+
+ self._tick_thread = threading.Thread(
+ target=self._loop,
+ name="ControlOrchestrator-Tick",
+ daemon=True,
+ )
+ self._tick_thread.start()
+ logger.info(f"TickLoop started at {self._tick_rate}Hz")
+
+ def stop(self) -> None:
+ """Stop the tick loop."""
+ self._stop_event.set()
+ if self._tick_thread and self._tick_thread.is_alive():
+ self._tick_thread.join(timeout=2.0)
+ logger.info("TickLoop stopped")
+
+ def _loop(self) -> None:
+ """Main control loop - deterministic read → compute → arbitrate → write."""
+ period = 1.0 / self._tick_rate
+
+ while not self._stop_event.is_set():
+ tick_start = time.perf_counter()
+
+ try:
+ self._tick()
+ except Exception as e:
+ logger.error(f"TickLoop tick error: {e}")
+
+ # Rate control - recalculate sleep time to account for overhead
+ next_tick_time = tick_start + period
+ sleep_time = next_tick_time - time.perf_counter()
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+
+ def _tick(self) -> None:
+ """Single tick: read → compute → arbitrate → route → write."""
+ t_now = time.perf_counter()
+ dt = t_now - self._last_tick_time
+ self._last_tick_time = t_now
+ self._tick_count += 1
+
+ # === PHASE 1: READ ALL HARDWARE ===
+ joint_states = self._read_all_hardware()
+ state = OrchestratorState(joints=joint_states, t_now=t_now, dt=dt)
+
+ # === PHASE 2: COMPUTE ALL ACTIVE TASKS ===
+ commands = self._compute_all_tasks(state)
+
+ # === PHASE 3: ARBITRATE (with mode validation) ===
+ joint_commands, preemptions = self._arbitrate(commands)
+
+ # === PHASE 4: NOTIFY PREEMPTIONS (once per task) ===
+ self._notify_preemptions(preemptions)
+
+ # === PHASE 5: ROUTE TO HARDWARE ===
+ hw_commands = self._route_to_hardware(joint_commands)
+
+ # === PHASE 6: WRITE TO HARDWARE ===
+ self._write_all_hardware(hw_commands)
+
+ # === PHASE 7: PUBLISH AGGREGATED STATE ===
+ if self._publish_callback:
+ self._publish_joint_state(joint_states)
+
+ # Optional logging
+ if self._log_ticks:
+ active = len([c for c in commands if c[2] is not None])
+ logger.debug(
+ f"Tick {self._tick_count}: dt={dt:.4f}s, "
+ f"{len(joint_states.joint_positions)} joints, "
+ f"{active} active tasks"
+ )
+
+ def _read_all_hardware(self) -> JointStateSnapshot:
+ """Read state from all hardware interfaces."""
+ joint_positions: dict[str, float] = {}
+ joint_velocities: dict[str, float] = {}
+ joint_efforts: dict[str, float] = {}
+
+ with self._hardware_lock:
+ for hw in self._hardware.values():
+ try:
+ state = hw.read_state()
+ for joint_name, (pos, vel, eff) in state.items():
+ joint_positions[joint_name] = pos
+ joint_velocities[joint_name] = vel
+ joint_efforts[joint_name] = eff
+ except Exception as e:
+ logger.error(f"Failed to read {hw.hardware_id}: {e}")
+
+ return JointStateSnapshot(
+ joint_positions=joint_positions,
+ joint_velocities=joint_velocities,
+ joint_efforts=joint_efforts,
+ timestamp=time.time(),
+ )
+
+ def _compute_all_tasks(
+ self, state: OrchestratorState
+ ) -> list[tuple[ControlTask, ResourceClaim, JointCommandOutput | None]]:
+ """Compute outputs from all active tasks."""
+ results: list[tuple[ControlTask, ResourceClaim, JointCommandOutput | None]] = []
+
+ with self._task_lock:
+ for task in self._tasks.values():
+ if not task.is_active():
+ continue
+
+ try:
+ claim = task.claim()
+ output = task.compute(state)
+ results.append((task, claim, output))
+ except Exception as e:
+ logger.error(f"Task {task.name} compute error: {e}")
+
+ return results
+
+ def _arbitrate(
+ self,
+ commands: list[tuple[ControlTask, ResourceClaim, JointCommandOutput | None]],
+ ) -> tuple[
+ dict[str, tuple[float, ControlMode, str]],
+ dict[str, dict[str, str]],
+ ]:
+ """Per-joint arbitration with mode conflict detection.
+
+ Returns:
+ Tuple of:
+ - joint_commands: {joint_name: (value, mode, task_name)}
+ - preemptions: {preempted_task: {joint: winning_task}}
+ """
+ winners: dict[str, JointWinner] = {} # joint_name -> current winner
+ preemptions: dict[str, dict[str, str]] = {} # loser_task -> {joint: winner_task}
+
+ for task, claim, output in commands:
+ if output is None:
+ continue
+
+ values = output.get_values()
+ if values is None:
+ continue
+
+ for i, joint_name in enumerate(output.joint_names):
+ candidate = JointWinner(claim.priority, values[i], output.mode, task.name)
+
+ # First claim on this joint
+ if joint_name not in winners:
+ winners[joint_name] = candidate
+ continue
+
+ current = winners[joint_name]
+
+ # Lower priority loses - notify preemption
+ if candidate.priority < current.priority:
+ preemptions.setdefault(task.name, {})[joint_name] = current.task_name
+ continue
+
+ # Higher priority - take over
+ if candidate.priority > current.priority:
+ preemptions.setdefault(current.task_name, {})[joint_name] = task.name
+ winners[joint_name] = candidate
+ continue
+
+ # Same priority - check for mode conflict
+ if candidate.mode != current.mode:
+ logger.warning(
+ f"Mode conflict on {joint_name}: {task.name} wants "
+ f"{candidate.mode.name}, but {current.task_name} wants "
+ f"{current.mode.name}. Dropping {task.name}."
+ )
+ preemptions.setdefault(task.name, {})[joint_name] = current.task_name
+ # Same priority + same mode: first wins (keep current)
+
+ # Convert to output format: joint -> (value, mode, task_name)
+ joint_commands = {joint: (w.value, w.mode, w.task_name) for joint, w in winners.items()}
+
+ return joint_commands, preemptions
+
+ def _notify_preemptions(self, preemptions: dict[str, dict[str, str]]) -> None:
+ """Notify each preempted task with affected joints, grouped by winning task."""
+ with self._task_lock:
+ for task_name, joint_winners in preemptions.items():
+ task = self._tasks.get(task_name)
+ if not task:
+ continue
+
+ # Group joints by winning task
+ by_winner: dict[str, set[str]] = {}
+ for joint, winner in joint_winners.items():
+ if winner not in by_winner:
+ by_winner[winner] = set()
+ by_winner[winner].add(joint)
+
+ # Notify once per distinct winning task
+ for winner, joints in by_winner.items():
+ try:
+ task.on_preempted(
+ by_task=winner,
+ joints=frozenset(joints),
+ )
+ except Exception as e:
+ logger.error(f"Error notifying {task_name} of preemption: {e}")
+
+ def _route_to_hardware(
+ self,
+ joint_commands: dict[str, tuple[float, ControlMode, str]],
+ ) -> dict[str, tuple[dict[str, float], ControlMode]]:
+ """Route joint-centric commands to hardware.
+
+ Returns:
+ {hardware_id: ({joint: value}, mode)}
+ """
+ hw_commands: dict[str, tuple[dict[str, float], ControlMode]] = {}
+
+ with self._hardware_lock:
+ for joint_name, (value, mode, _) in joint_commands.items():
+ hw_id = self._joint_to_hardware.get(joint_name)
+ if hw_id is None:
+ logger.warning(f"Unknown joint {joint_name}, cannot route")
+ continue
+
+ if hw_id not in hw_commands:
+ hw_commands[hw_id] = ({}, mode)
+ else:
+ # Check for mode conflict across joints on same hardware
+ existing_mode = hw_commands[hw_id][1]
+ if mode != existing_mode:
+ logger.error(
+ f"Mode conflict for hardware {hw_id}: joint {joint_name} wants "
+ f"{mode.name} but hardware already has {existing_mode.name}. "
+ f"Dropping command for {joint_name}."
+ )
+ continue
+
+ hw_commands[hw_id][0][joint_name] = value
+
+ return hw_commands
+
+ def _write_all_hardware(
+ self,
+ hw_commands: dict[str, tuple[dict[str, float], ControlMode]],
+ ) -> None:
+ """Write commands to all hardware interfaces."""
+ with self._hardware_lock:
+ for hw_id, (positions, mode) in hw_commands.items():
+ if hw_id in self._hardware:
+ try:
+ self._hardware[hw_id].write_command(positions, mode)
+ except Exception as e:
+ logger.error(f"Failed to write to {hw_id}: {e}")
+
+ def _publish_joint_state(self, snapshot: JointStateSnapshot) -> None:
+ """Publish aggregated JointState for external consumers."""
+ names = list(snapshot.joint_positions.keys())
+ msg = JointState(
+ ts=snapshot.timestamp,
+ frame_id=self._frame_id,
+ name=names,
+ position=[snapshot.joint_positions[n] for n in names],
+ velocity=[snapshot.joint_velocities.get(n, 0.0) for n in names],
+ effort=[snapshot.joint_efforts.get(n, 0.0) for n in names],
+ )
+ if self._publish_callback:
+ self._publish_callback(msg)
+
+
+__all__ = ["TickLoop"]
diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py
index 25d4f7a6e5..b56fe74f4f 100644
--- a/dimos/core/__init__.py
+++ b/dimos/core/__init__.py
@@ -177,7 +177,7 @@ def close_all() -> None:
from dimos.protocol.pubsub import shmpubsub
for obj in gc.get_objects():
- if isinstance(obj, shmpubsub.SharedMemory | shmpubsub.PickleSharedMemory):
+ if isinstance(obj, shmpubsub.SharedMemoryPubSubBase):
try:
obj.stop()
except Exception:
diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py
index bfb553a45d..205c38c361 100644
--- a/dimos/core/global_config.py
+++ b/dimos/core/global_config.py
@@ -33,7 +33,7 @@ class GlobalConfig(BaseSettings):
replay: bool = False
rerun_enabled: bool = True
rerun_server_addr: str | None = None
- viewer_backend: ViewerBackend = "rerun-native"
+ viewer_backend: ViewerBackend = "rerun-web"
n_dask_workers: int = 2
memory_limit: str = "auto"
mujoco_camera_position: str | None = None
diff --git a/dimos/core/resource.py b/dimos/core/resource.py
index 21cdec6322..ce3f735329 100644
--- a/dimos/core/resource.py
+++ b/dimos/core/resource.py
@@ -21,3 +21,25 @@ def start(self) -> None: ...
@abstractmethod
def stop(self) -> None: ...
+
+ def dispose(self) -> None:
+ """
+        Makes a Resource disposable,
+        so you can do:
+
+ from reactivex.disposable import CompositeDisposable
+
+ disposables = CompositeDisposable()
+
+ transport1 = LCMTransport(...)
+ transport2 = LCMTransport(...)
+
+ disposables.add(transport1)
+ disposables.add(transport2)
+
+ ...
+
+ disposables.dispose()
+
+ """
+ self.stop()
diff --git a/dimos/core/stream.py b/dimos/core/stream.py
index 66d8cf4ef5..64a1e0edce 100644
--- a/dimos/core/stream.py
+++ b/dimos/core/stream.py
@@ -86,7 +86,9 @@ def broadcast(self, selfstream: Out[T], value: T) -> None:
raise NotImplementedError
# used by local Input
- def subscribe(self, callback: Callable[[T], Any], selfstream: Stream[T]) -> Callable[[], None]:
+ def subscribe(
+ self, callback: Callable[[T], Any], selfstream: Stream[T] | None = None
+ ) -> Callable[[], None]:
raise NotImplementedError
def publish(self, msg: T) -> None:
diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py
index 7a99a23abe..a8b9354f70 100644
--- a/dimos/core/test_blueprints.py
+++ b/dimos/core/test_blueprints.py
@@ -162,6 +162,7 @@ def test_global_config() -> None:
assert blueprint_set.global_config_overrides["option2"] == 42
+@pytest.mark.integration
def test_build_happy_path() -> None:
pubsub.lcm.autoconf()
@@ -272,6 +273,7 @@ class Module3(Module):
blueprint_set_remapped._verify_no_name_conflicts()
+@pytest.mark.integration
def test_remapping() -> None:
"""Test that remapping connections works correctly."""
pubsub.lcm.autoconf()
@@ -351,6 +353,7 @@ def test_future_annotations_support() -> None:
)
+@pytest.mark.integration
def test_future_annotations_autoconnect() -> None:
"""Test that autoconnect works with modules using `from __future__ import annotations`."""
diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py
index 597b580c5c..fde1cf5df2 100644
--- a/dimos/core/test_core.py
+++ b/dimos/core/test_core.py
@@ -28,7 +28,7 @@
)
from dimos.core.testing import MockRobotClient, dimos
from dimos.msgs.geometry_msgs import Vector3
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
from dimos.robot.unitree_webrtc.type.odometry import Odometry
assert dimos
@@ -36,7 +36,7 @@
class Navigation(Module):
mov: Out[Vector3]
- lidar: In[LidarMessage]
+ lidar: In[PointCloud2]
target_position: In[Vector3]
odometry: In[Odometry]
@@ -113,7 +113,7 @@ def test_basic_deployment(dimos) -> None:
nav = dimos.deploy(Navigation)
# this one encodes proper LCM messages
- robot.lidar.transport = LCMTransport("/lidar", LidarMessage)
+ robot.lidar.transport = LCMTransport("/lidar", PointCloud2)
# odometry & mov using just a pickle over LCM
robot.odometry.transport = pLCMTransport("/odom")
diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py
index 4909cd8cc5..b963022c50 100644
--- a/dimos/core/test_stream.py
+++ b/dimos/core/test_stream.py
@@ -24,7 +24,7 @@
rpc,
)
from dimos.core.testing import MockRobotClient, dimos
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
from dimos.robot.unitree_webrtc.type.odometry import Odometry
assert dimos
@@ -157,7 +157,7 @@ def wrapped_unsubscribe() -> None:
def test_subscription(dimos, subscriber_class) -> None:
robot = dimos.deploy(MockRobotClient)
- robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage)
+ robot.lidar.transport = SpyLCMTransport("/lidar", PointCloud2)
robot.odometry.transport = SpyLCMTransport("/odom", Odometry)
subscriber = dimos.deploy(subscriber_class)
@@ -195,7 +195,7 @@ def test_subscription(dimos, subscriber_class) -> None:
def test_get_next(dimos) -> None:
robot = dimos.deploy(MockRobotClient)
- robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage)
+ robot.lidar.transport = SpyLCMTransport("/lidar", PointCloud2)
robot.odometry.transport = SpyLCMTransport("/odom", Odometry)
subscriber = dimos.deploy(RXPYSubscriber)
@@ -224,7 +224,7 @@ def test_get_next(dimos) -> None:
def test_hot_getter(dimos) -> None:
robot = dimos.deploy(MockRobotClient)
- robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage)
+ robot.lidar.transport = SpyLCMTransport("/lidar", PointCloud2)
robot.odometry.transport = SpyLCMTransport("/odom", Odometry)
subscriber = dimos.deploy(RXPYSubscriber)
diff --git a/dimos/core/testing.py b/dimos/core/testing.py
index 832f1f985b..38774ef327 100644
--- a/dimos/core/testing.py
+++ b/dimos/core/testing.py
@@ -19,7 +19,8 @@
from dimos.core import In, Module, Out, rpc, start
from dimos.msgs.geometry_msgs import Vector3
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
+from dimos.robot.unitree_webrtc.type.lidar import pointcloud2_from_webrtc_lidar
from dimos.robot.unitree_webrtc.type.odometry import Odometry
from dimos.utils.testing import SensorReplay
@@ -34,7 +35,7 @@ def dimos(): # type: ignore[no-untyped-def]
class MockRobotClient(Module):
odometry: Out[Odometry]
- lidar: Out[LidarMessage]
+ lidar: Out[PointCloud2]
mov: In[Vector3]
mov_msg_count = 0
@@ -65,7 +66,7 @@ def stop(self) -> None:
def odomloop(self) -> None:
odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg)
- lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg)
+ lidardata = SensorReplay("office_lidar", autocast=pointcloud2_from_webrtc_lidar)
lidariter = lidardata.iterate()
self._stop_event.clear()
diff --git a/dimos/core/transport.py b/dimos/core/transport.py
index 8ffbfc91f4..4c1b19ee2e 100644
--- a/dimos/core/transport.py
+++ b/dimos/core/transport.py
@@ -26,9 +26,11 @@
)
from dimos.core.stream import In, Out, Stream, Transport
+from dimos.msgs.protocol import DimosMsg
from dimos.protocol.pubsub.jpeg_shm import JpegSharedMemory
from dimos.protocol.pubsub.lcmpubsub import LCM, JpegLCM, PickleLCM, Topic as LCMTopic
-from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory
+from dimos.protocol.pubsub.rospubsub import DimosROS, ROSTopic
+from dimos.protocol.pubsub.shmpubsub import BytesSharedMemory, PickleSharedMemory
if TYPE_CHECKING:
from collections.abc import Callable
@@ -62,8 +64,7 @@ def __reduce__(self): # type: ignore[no-untyped-def]
def broadcast(self, _: Out[T] | None, msg: T) -> None:
if not self._started:
- self.lcm.start()
- self._started = True
+ self.start()
self.lcm.publish(self.topic, msg)
@@ -71,14 +72,16 @@ def subscribe(
self, callback: Callable[[T], Any], selfstream: Stream[T] | None = None
) -> Callable[[], None]:
if not self._started:
- self.lcm.start()
- self._started = True
+ self.start()
return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg))
- def start(self) -> None: ...
+ def start(self) -> None:
+ self.lcm.start()
+ self._started = True
def stop(self) -> None:
self.lcm.stop()
+ self._started = False
class LCMTransport(PubSubTransport[T]):
@@ -89,25 +92,26 @@ def __init__(self, topic: str, type: type, **kwargs) -> None: # type: ignore[no
if not hasattr(self, "lcm"):
self.lcm = LCM(**kwargs)
- def start(self) -> None: ...
+ def start(self) -> None:
+ self.lcm.start()
+ self._started = True
def stop(self) -> None:
self.lcm.stop()
+ self._started = False
def __reduce__(self): # type: ignore[no-untyped-def]
return (LCMTransport, (self.topic.topic, self.topic.lcm_type))
def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def]
if not self._started:
- self.lcm.start()
- self._started = True
+ self.start()
self.lcm.publish(self.topic, msg)
def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override]
if not self._started:
- self.lcm.start()
- self._started = True
+ self.start()
return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[return-value]
@@ -119,10 +123,13 @@ def __init__(self, topic: str, type: type, **kwargs) -> None: # type: ignore[no
def __reduce__(self): # type: ignore[no-untyped-def]
return (JpegLcmTransport, (self.topic.topic, self.topic.lcm_type))
- def start(self) -> None: ...
+ def start(self) -> None:
+ self.lcm.start()
+ self._started = True
def stop(self) -> None:
self.lcm.stop()
+ self._started = False
class pSHMTransport(PubSubTransport[T]):
@@ -137,21 +144,22 @@ def __reduce__(self): # type: ignore[no-untyped-def]
def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def]
if not self._started:
- self.shm.start()
- self._started = True
+ self.start()
self.shm.publish(self.topic, msg)
def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override]
if not self._started:
- self.shm.start()
- self._started = True
+ self.start()
return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[return-value]
- def start(self) -> None: ...
+ def start(self) -> None:
+ self.shm.start()
+ self._started = True
def stop(self) -> None:
self.shm.stop()
+ self._started = False
class SHMTransport(PubSubTransport[T]):
@@ -159,28 +167,29 @@ class SHMTransport(PubSubTransport[T]):
def __init__(self, topic: str, **kwargs) -> None: # type: ignore[no-untyped-def]
super().__init__(topic)
- self.shm = SharedMemory(**kwargs)
+ self.shm = BytesSharedMemory(**kwargs)
def __reduce__(self): # type: ignore[no-untyped-def]
return (SHMTransport, (self.topic,))
def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def]
if not self._started:
- self.shm.start()
- self._started = True
+ self.start()
self.shm.publish(self.topic, msg)
def subscribe(self, callback: Callable[[T], None], selfstream: In[T] | None = None) -> None: # type: ignore[override]
if not self._started:
- self.shm.start()
- self._started = True
+ self.start()
return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[arg-type, return-value]
- def start(self) -> None: ...
+ def start(self) -> None:
+ self.shm.start()
+ self._started = True
def stop(self) -> None:
self.shm.stop()
+ self._started = False
class JpegShmTransport(PubSubTransport[T]):
@@ -196,20 +205,57 @@ def __reduce__(self): # type: ignore[no-untyped-def]
def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def]
if not self._started:
- self.shm.start()
- self._started = True
+ self.start()
self.shm.publish(self.topic, msg)
def subscribe(self, callback: Callable[[T], None], selfstream: In[T] | None = None) -> None: # type: ignore[override]
if not self._started:
- self.shm.start()
- self._started = True
+ self.start()
return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[arg-type, return-value]
- def start(self) -> None: ...
+ def start(self) -> None:
+ self.shm.start()
+ self._started = True
- def stop(self) -> None: ...
+ def stop(self) -> None:
+ self.shm.stop()
+ self._started = False
+
+
+class ROSTransport(PubSubTransport[DimosMsg]):
+ _ros: DimosROS | None = None
+
+ def __init__(self, topic: str, msg_type: type[DimosMsg], **kwargs: Any) -> None:
+ super().__init__(ROSTopic(topic, msg_type))
+ self._kwargs = kwargs
+
+ def __reduce__(self) -> tuple[Any, ...]:
+ return (ROSTransport, (self.topic.topic, self.topic.msg_type))
+
+ def broadcast(self, _: Out[DimosMsg], msg: DimosMsg) -> None:
+ if self._ros is None:
+ self.start()
+ assert self._ros is not None # for type narrowing
+ self._ros.publish(self.topic, msg)
+
+ def subscribe(
+ self, callback: Callable[[DimosMsg], Any], selfstream: Stream[DimosMsg] | None = None
+ ) -> Callable[[], None]:
+ if self._ros is None:
+ self.start()
+ assert self._ros is not None # for type narrowing
+ return self._ros.subscribe(self.topic, lambda msg, topic: callback(msg))
+
+ def start(self) -> None:
+ if self._ros is None:
+ self._ros = DimosROS(**self._kwargs)
+ self._ros.start()
+
+ def stop(self) -> None:
+ if self._ros is not None:
+ self._ros.stop()
+ self._ros = None
class ZenohTransport(PubSubTransport[T]): ...
diff --git a/dimos/dashboard/rerun_scene_wiring.py b/dimos/dashboard/rerun_scene_wiring.py
new file mode 100644
index 0000000000..56efe306ea
--- /dev/null
+++ b/dimos/dashboard/rerun_scene_wiring.py
@@ -0,0 +1,152 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Rerun scene wiring helpers (static attachments, URDF, pinholes).
+
+This module is intentionally *not* a TF visualizer.
+It only provides static Rerun scene setup:
+- view coordinates
+- attach semantic entity paths (world/robot/...) under named TF frames (base_link, camera_optical, ...)
+- optional URDF logging
+- optional axes gizmo + camera pinhole(s)
+
+Dynamic TF visualization remains the responsibility of `TFRerunModule`.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Protocol
+
+import rerun as rr
+
+from dimos.core import Module, rpc
+from dimos.core.global_config import GlobalConfig
+from dimos.dashboard.rerun_init import connect_rerun
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+
+class _HasToRerun(Protocol):
+ def to_rerun(self) -> Any: ...
+
+
+def _attach_entity(entity_path: str, parent_frame: str) -> None:
+ """Attach an entity path's implicit frame (tf#/...) under a named frame."""
+ rr.log(
+ entity_path,
+ rr.Transform3D(
+ translation=[0.0, 0.0, 0.0],
+ rotation=rr.Quaternion(xyzw=[0.0, 0.0, 0.0, 1.0]),
+ parent_frame=parent_frame, # type: ignore[call-arg]
+ ),
+ static=True,
+ )
+
+
+class RerunSceneWiringModule(Module):
+ """Static Rerun scene wiring for semantic entity paths."""
+
+ _global_config: GlobalConfig
+
+ # Semantic entity roots
+ world_entity: str
+ robot_entity: str
+ robot_axes_entity: str
+
+ # Named TF frames to attach to
+ world_frame: str
+ robot_frame: str
+
+ # Optional assets
+ urdf_path: str | Path | None
+ axes_size: float | None
+
+ # Multi-camera wiring:
+ # tuple = (camera_entity_path, camera_named_frame, camera_info_static)
+ cameras: Sequence[tuple[str, str, _HasToRerun]]
+ camera_rgb_suffix: str
+
+ def __init__(
+ self,
+ *,
+ global_config: GlobalConfig | None = None,
+ world_entity: str = "world",
+ robot_entity: str = "world/robot",
+ robot_axes_entity: str = "world/robot/axes",
+ world_frame: str = "world",
+ robot_frame: str = "base_link",
+ urdf_path: str | Path | None = None,
+ axes_size: float | None = 0.5,
+ cameras: Sequence[tuple[str, str, _HasToRerun]] = (),
+ camera_rgb_suffix: str = "rgb",
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self._global_config = global_config or GlobalConfig()
+
+ self.world_entity = world_entity
+ self.robot_entity = robot_entity
+ self.robot_axes_entity = robot_axes_entity
+
+ self.world_frame = world_frame
+ self.robot_frame = robot_frame
+
+ self.urdf_path = urdf_path
+ self.axes_size = axes_size
+
+ self.cameras = cameras
+ self.camera_rgb_suffix = camera_rgb_suffix
+
+ @rpc
+ def start(self) -> None:
+ super().start()
+
+ if not self._global_config.viewer_backend.startswith("rerun"):
+ return
+
+ connect_rerun(global_config=self._global_config)
+
+ # Global view coordinates (applies to views at/under this origin).
+ rr.log(self.world_entity, rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)
+
+ # Attach semantic entity paths to named TF frames.
+ _attach_entity(self.world_entity, self.world_frame)
+ _attach_entity(self.robot_entity, self.robot_frame)
+
+ if self.axes_size is not None:
+ rr.log(self.robot_axes_entity, rr.TransformAxes3D(self.axes_size), static=True) # type: ignore[attr-defined]
+
+ # Optional URDF load (purely visual).
+ if self.urdf_path is not None:
+ p = Path(self.urdf_path)
+ if p.exists():
+ rr.log_file_from_path(
+ str(p),
+ entity_path_prefix=self.robot_entity,
+ static=True,
+ )
+
+ # Multi-camera: attach camera entities + log static pinholes.
+ for cam_entity, cam_frame, cam_info in self.cameras:
+ _attach_entity(cam_entity, cam_frame)
+ rr.log(cam_entity, cam_info.to_rerun(), static=True) # type: ignore[no-untyped-call]
+
+ @rpc
+ def stop(self) -> None:
+ super().stop()
+
+
+rerun_scene_wiring = RerunSceneWiringModule.blueprint
diff --git a/dimos/dashboard/tf_rerun_module.py b/dimos/dashboard/tf_rerun_module.py
index c862778cad..bca05ce2e4 100644
--- a/dimos/dashboard/tf_rerun_module.py
+++ b/dimos/dashboard/tf_rerun_module.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""TF Rerun Module - Automatically visualize all transforms in Rerun.
+"""TF Rerun Module - Snapshot TF visualization in Rerun.
-This module subscribes to the /tf LCM topic and logs ALL transforms
-to Rerun, providing automatic visualization of the robot's TF tree.
+This module polls the TF buffer at a configurable rate and logs the latest
+transform for each edge to Rerun. This provides stable, rate-limited TF
+visualization without subscribing to the /tf transport from here.
Usage:
# In blueprints:
@@ -29,35 +30,41 @@ def my_robot():
)
"""
-from typing import Any
+from collections.abc import Sequence
+import threading
+import time
+from typing import Any, cast
import rerun as rr
from dimos.core import Module, rpc
+from dimos.core.blueprints import ModuleBlueprintSet, autoconnect
from dimos.core.global_config import GlobalConfig
from dimos.dashboard.rerun_init import connect_rerun
-from dimos.msgs.tf2_msgs import TFMessage
-from dimos.protocol.pubsub.lcmpubsub import LCM, Topic
+from dimos.dashboard.rerun_scene_wiring import rerun_scene_wiring
from dimos.utils.logging_config import setup_logger
logger = setup_logger()
class TFRerunModule(Module):
- """Subscribes to /tf LCM topic and logs all transforms to Rerun.
+ """Polls TF buffer and logs snapshot transforms to Rerun.
This module automatically visualizes the TF tree in Rerun by:
- - Subscribing to the /tf LCM topic (captures ALL transforms in the system)
- - Logging each transform to its derived entity path (world/{child_frame_id})
+ - Using `self.tf` (the system TF service) to maintain the TF buffer
+ - Polling at a configurable rate and logging the latest transform per edge
"""
_global_config: GlobalConfig
- _lcm: LCM | None = None
- _unsubscribe: Any = None
+ _poll_thread: threading.Thread | None = None
+ _stop_event: threading.Event | None = None
+ _poll_hz: float
+ _last_ts_by_edge: dict[tuple[str, str], float]
def __init__(
self,
global_config: GlobalConfig | None = None,
+ poll_hz: float = 30.0,
**kwargs: Any,
) -> None:
"""Initialize TFRerunModule.
@@ -68,6 +75,8 @@ def __init__(
"""
super().__init__(**kwargs)
self._global_config = global_config or GlobalConfig()
+ self._poll_hz = poll_hz
+ self._last_ts_by_edge = {}
@rpc
def start(self) -> None:
@@ -78,35 +87,86 @@ def start(self) -> None:
if self._global_config.viewer_backend.startswith("rerun"):
connect_rerun(global_config=self._global_config)
- # Subscribe directly to LCM /tf topic (captures ALL transforms)
- self._lcm = LCM()
- self._lcm.start()
- topic = Topic("/tf", TFMessage)
- self._unsubscribe = self._lcm.subscribe(topic, self._on_tf_message)
- logger.info("TFRerunModule: subscribed to /tf, logging all transforms to Rerun")
+ # Ensure TF transport is started so its internal subscription populates the buffer.
+ self.tf.start(sub=True)
- def _on_tf_message(self, msg: TFMessage, topic: Topic) -> None:
- """Log all transforms in TFMessage to Rerun.
+ self._stop_event = threading.Event()
+ self._poll_thread = threading.Thread(target=self._poll_loop, daemon=True)
+ self._poll_thread.start()
+ logger.info("TFRerunModule: started TF snapshot polling", poll_hz=self._poll_hz)
- Args:
- msg: TFMessage containing transforms to visualize
- topic: The LCM topic (unused but required by callback signature)
- """
- for entity_path, transform in msg.to_rerun(): # type: ignore[no-untyped-call]
- rr.log(entity_path, transform)
+ def _poll_loop(self) -> None:
+ assert self._stop_event is not None
+ period_s = 1.0 / max(self._poll_hz, 0.1)
+
+ while not self._stop_event.is_set():
+ # Snapshot keys to avoid concurrent modification while TF buffer updates.
+ items = list(self.tf.buffers.items()) # type: ignore[attr-defined]
+ for (parent, child), buffer in items:
+ latest = buffer.get()
+ if latest is None:
+ continue
+ last_ts = self._last_ts_by_edge.get((parent, child))
+ if last_ts is not None and latest.ts == last_ts:
+ continue
+
+ # Log under `world/tf/...` so it is visible under the default 3D view origin.
+ rr.log(f"world/tf/{child}", latest.to_rerun()) # type: ignore[no-untyped-call]
+ self._last_ts_by_edge[(parent, child)] = latest.ts
+
+ time.sleep(period_s)
@rpc
def stop(self) -> None:
"""Stop the TF visualization module and cleanup LCM subscription."""
- if self._unsubscribe:
- self._unsubscribe()
- self._unsubscribe = None
+        if self._stop_event is not None:
+            self._stop_event.set()
-        if self._lcm:
-            self._lcm.stop()
-            self._lcm = None
+        if self._poll_thread is not None and self._poll_thread.is_alive():
+            self._poll_thread.join(timeout=1.0)
+        self._poll_thread = None
+        self._stop_event = None
super().stop()
-tf_rerun = TFRerunModule.blueprint
+def tf_rerun(
+ *,
+ poll_hz: float = 30.0,
+ scene: bool = True,
+ # Scene wiring kwargs (only used if scene=True)
+ world_entity: str = "world",
+ robot_entity: str = "world/robot",
+ robot_axes_entity: str = "world/robot/axes",
+ world_frame: str = "world",
+ robot_frame: str = "base_link",
+ urdf_path: str | None = None,
+ axes_size: float | None = 0.5,
+ cameras: Sequence[tuple[str, str, Any]] = (),
+ camera_rgb_suffix: str = "rgb",
+) -> ModuleBlueprintSet:
+ """Convenience blueprint: TF snapshot polling + (optional) static scene wiring.
+
+ - TF visualization stays in `TFRerunModule` (poll TF buffer, log to `world/tf/*`).
+ - Scene wiring is handled by `RerunSceneWiringModule` (view coords, attachments, URDF, pinholes).
+ """
+ tf_bp = cast("ModuleBlueprintSet", TFRerunModule.blueprint(poll_hz=poll_hz))
+ if not scene:
+ return tf_bp
+
+ scene_bp = cast(
+ "ModuleBlueprintSet",
+ rerun_scene_wiring(
+ world_entity=world_entity,
+ robot_entity=robot_entity,
+ robot_axes_entity=robot_axes_entity,
+ world_frame=world_frame,
+ robot_frame=robot_frame,
+ urdf_path=urdf_path,
+ axes_size=axes_size,
+ cameras=cameras,
+ camera_rgb_suffix=camera_rgb_suffix,
+ ),
+ )
+
+ return autoconnect(tf_bp, scene_bp)
diff --git a/dimos/e2e_tests/conftest.py b/dimos/e2e_tests/conftest.py
index 12d3e407ae..46b92151e9 100644
--- a/dimos/e2e_tests/conftest.py
+++ b/dimos/e2e_tests/conftest.py
@@ -63,8 +63,8 @@ def fun(*, points: list[tuple[float, float, float]], fail_message: str) -> None:
def start_blueprint() -> Iterator[Callable[[str], DimosCliCall]]:
dimos_robot_call = DimosCliCall()
- def set_name_and_start(demo_name: str) -> DimosCliCall:
- dimos_robot_call.demo_name = demo_name
+ def set_name_and_start(*demo_args: str) -> DimosCliCall:
+ dimos_robot_call.demo_args = list(demo_args)
dimos_robot_call.start()
return dimos_robot_call
diff --git a/dimos/e2e_tests/dimos_cli_call.py b/dimos/e2e_tests/dimos_cli_call.py
index 07def58782..2e987cf7ad 100644
--- a/dimos/e2e_tests/dimos_cli_call.py
+++ b/dimos/e2e_tests/dimos_cli_call.py
@@ -19,18 +19,20 @@
class DimosCliCall:
process: subprocess.Popen[bytes] | None
- demo_name: str | None = None
+ demo_args: list[str] | None = None
def __init__(self) -> None:
self.process = None
def start(self) -> None:
- if self.demo_name is None:
- raise ValueError("Demo name must be set before starting the process.")
+ if self.demo_args is None:
+ raise ValueError("Demo args must be set before starting the process.")
- self.process = subprocess.Popen(
- ["dimos", "--simulation", "run", self.demo_name],
- )
+ args = list(self.demo_args)
+ if len(args) == 1:
+ args = ["run", *args]
+
+ self.process = subprocess.Popen(["dimos", "--simulation", *args])
def stop(self) -> None:
if self.process is None:
diff --git a/dimos/e2e_tests/test_control_orchestrator.py b/dimos/e2e_tests/test_control_orchestrator.py
new file mode 100644
index 0000000000..aa820d66ec
--- /dev/null
+++ b/dimos/e2e_tests/test_control_orchestrator.py
@@ -0,0 +1,264 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""End-to-end tests for the ControlOrchestrator.
+
+These tests start a real orchestrator process and communicate via LCM/RPC.
+Unlike unit tests, these verify the full system integration.
+
+Run with:
+ pytest dimos/e2e_tests/test_control_orchestrator.py -v -s
+"""
+
+import os
+import time
+
+import pytest
+
+from dimos.control.orchestrator import ControlOrchestrator
+from dimos.core.rpc_client import RPCClient
+from dimos.msgs.sensor_msgs import JointState
+from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint, TrajectoryState
+
+
+@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM doesn't work in CI.")
+@pytest.mark.e2e
+class TestControlOrchestratorE2E:
+ """End-to-end tests for ControlOrchestrator."""
+
+ def test_orchestrator_starts_and_responds_to_rpc(self, lcm_spy, start_blueprint) -> None:
+ """Test that orchestrator starts and responds to RPC queries."""
+ # Save topics we care about (LCM topics include type suffix)
+ joint_state_topic = "/orchestrator/joint_state#sensor_msgs.JointState"
+ lcm_spy.save_topic(joint_state_topic)
+ lcm_spy.save_topic("/rpc/ControlOrchestrator/list_joints/res")
+ lcm_spy.save_topic("/rpc/ControlOrchestrator/list_tasks/res")
+
+ # Start the mock orchestrator blueprint
+ start_blueprint("orchestrator-mock")
+
+ # Wait for joint state to be published (proves tick loop is running)
+ lcm_spy.wait_for_saved_topic(
+ joint_state_topic,
+ timeout=10.0,
+ )
+
+ # Create RPC client and query
+ client = RPCClient(None, ControlOrchestrator)
+ try:
+ # Test list_joints RPC
+ joints = client.list_joints()
+ assert joints is not None
+ assert len(joints) == 7 # Mock arm has 7 DOF
+ assert "arm_joint1" in joints
+
+ # Test list_tasks RPC
+ tasks = client.list_tasks()
+ assert tasks is not None
+ assert "traj_arm" in tasks
+
+ # Test list_hardware RPC
+ hardware = client.list_hardware()
+ assert hardware is not None
+ assert "arm" in hardware
+ finally:
+ client.stop_rpc_client()
+
+ def test_orchestrator_executes_trajectory(self, lcm_spy, start_blueprint) -> None:
+ """Test that orchestrator executes a trajectory via RPC."""
+ # Save topics
+ lcm_spy.save_topic("/orchestrator/joint_state#sensor_msgs.JointState")
+ lcm_spy.save_topic("/rpc/ControlOrchestrator/execute_trajectory/res")
+ lcm_spy.save_topic("/rpc/ControlOrchestrator/get_trajectory_status/res")
+
+ # Start orchestrator
+ start_blueprint("orchestrator-mock")
+
+ # Wait for it to be ready
+ lcm_spy.wait_for_saved_topic(
+ "/orchestrator/joint_state#sensor_msgs.JointState", timeout=10.0
+ )
+
+ # Create RPC client
+ client = RPCClient(None, ControlOrchestrator)
+ try:
+ # Get initial joint positions
+ initial_positions = client.get_joint_positions()
+ assert initial_positions is not None
+
+ # Create a simple trajectory
+ trajectory = JointTrajectory(
+ joint_names=[f"arm_joint{i + 1}" for i in range(7)],
+ points=[
+ TrajectoryPoint(
+ time_from_start=0.0,
+ positions=[0.0] * 7,
+ velocities=[0.0] * 7,
+ ),
+ TrajectoryPoint(
+ time_from_start=0.5,
+ positions=[0.1] * 7,
+ velocities=[0.0] * 7,
+ ),
+ ],
+ )
+
+ # Execute trajectory
+ result = client.execute_trajectory("traj_arm", trajectory)
+ assert result is True
+
+ # Poll for completion
+ timeout = 5.0
+ start_time = time.time()
+ completed = False
+
+ while time.time() - start_time < timeout:
+ status = client.get_trajectory_status("traj_arm")
+ if status is not None and status.state == TrajectoryState.COMPLETED.name:
+ completed = True
+ break
+ time.sleep(0.1)
+
+ assert completed, "Trajectory did not complete within timeout"
+ finally:
+ client.stop_rpc_client()
+
+ def test_orchestrator_joint_state_published(self, lcm_spy, start_blueprint) -> None:
+ """Test that joint state messages are published at expected rate."""
+ joint_state_topic = "/orchestrator/joint_state#sensor_msgs.JointState"
+ lcm_spy.save_topic(joint_state_topic)
+
+ # Start orchestrator
+ start_blueprint("orchestrator-mock")
+
+ # Wait for initial message
+ lcm_spy.wait_for_saved_topic(joint_state_topic, timeout=10.0)
+
+ # Collect messages for 1 second
+ time.sleep(1.0)
+
+ # Check we received messages (should be ~100 at 100Hz)
+ with lcm_spy._messages_lock:
+ message_count = len(lcm_spy.messages.get(joint_state_topic, []))
+
+ # Allow some tolerance (at least 50 messages in 1 second)
+ assert message_count >= 50, f"Expected ~100 messages, got {message_count}"
+
+ # Decode a message to verify structure
+ with lcm_spy._messages_lock:
+ raw_msg = lcm_spy.messages[joint_state_topic][0]
+
+ joint_state = JointState.lcm_decode(raw_msg)
+ assert len(joint_state.name) == 7
+ assert len(joint_state.position) == 7
+ assert "arm_joint1" in joint_state.name
+
+ def test_orchestrator_cancel_trajectory(self, lcm_spy, start_blueprint) -> None:
+ """Test that a running trajectory can be cancelled."""
+ lcm_spy.save_topic("/orchestrator/joint_state#sensor_msgs.JointState")
+
+ # Start orchestrator
+ start_blueprint("orchestrator-mock")
+ lcm_spy.wait_for_saved_topic(
+ "/orchestrator/joint_state#sensor_msgs.JointState", timeout=10.0
+ )
+
+ client = RPCClient(None, ControlOrchestrator)
+ try:
+ # Create a long trajectory (5 seconds)
+ trajectory = JointTrajectory(
+ joint_names=[f"arm_joint{i + 1}" for i in range(7)],
+ points=[
+ TrajectoryPoint(
+ time_from_start=0.0,
+ positions=[0.0] * 7,
+ velocities=[0.0] * 7,
+ ),
+ TrajectoryPoint(
+ time_from_start=5.0,
+ positions=[1.0] * 7,
+ velocities=[0.0] * 7,
+ ),
+ ],
+ )
+
+ # Start trajectory
+ result = client.execute_trajectory("traj_arm", trajectory)
+ assert result is True
+
+ # Wait a bit then cancel
+ time.sleep(0.5)
+ cancel_result = client.cancel_trajectory("traj_arm")
+ assert cancel_result is True
+
+ # Check status is ABORTED
+ status = client.get_trajectory_status("traj_arm")
+ assert status is not None
+ assert status.state == TrajectoryState.ABORTED.name
+ finally:
+ client.stop_rpc_client()
+
+ def test_dual_arm_orchestrator(self, lcm_spy, start_blueprint) -> None:
+ """Test dual-arm orchestrator with independent trajectories."""
+ lcm_spy.save_topic("/orchestrator/joint_state#sensor_msgs.JointState")
+
+ # Start dual-arm mock orchestrator
+ start_blueprint("orchestrator-dual-mock")
+ lcm_spy.wait_for_saved_topic(
+ "/orchestrator/joint_state#sensor_msgs.JointState", timeout=10.0
+ )
+
+ client = RPCClient(None, ControlOrchestrator)
+ try:
+ # Verify both arms present
+ joints = client.list_joints()
+ assert "left_joint1" in joints
+ assert "right_joint1" in joints
+
+ tasks = client.list_tasks()
+ assert "traj_left" in tasks
+ assert "traj_right" in tasks
+
+ # Create trajectories for both arms
+ left_trajectory = JointTrajectory(
+ joint_names=[f"left_joint{i + 1}" for i in range(7)],
+ points=[
+ TrajectoryPoint(time_from_start=0.0, positions=[0.0] * 7),
+ TrajectoryPoint(time_from_start=0.5, positions=[0.2] * 7),
+ ],
+ )
+
+ right_trajectory = JointTrajectory(
+ joint_names=[f"right_joint{i + 1}" for i in range(6)],
+ points=[
+ TrajectoryPoint(time_from_start=0.0, positions=[0.0] * 6),
+ TrajectoryPoint(time_from_start=0.5, positions=[0.3] * 6),
+ ],
+ )
+
+ # Execute both
+ assert client.execute_trajectory("traj_left", left_trajectory) is True
+ assert client.execute_trajectory("traj_right", right_trajectory) is True
+
+ # Wait for completion
+ time.sleep(1.0)
+
+ # Both should complete
+ left_status = client.get_trajectory_status("traj_left")
+ right_status = client.get_trajectory_status("traj_right")
+
+ assert left_status is not None and left_status.state == TrajectoryState.COMPLETED.name
+ assert right_status is not None and right_status.state == TrajectoryState.COMPLETED.name
+ finally:
+ client.stop_rpc_client()
diff --git a/dimos/e2e_tests/test_dimos_cli_e2e.py b/dimos/e2e_tests/test_dimos_cli_e2e.py
index 7571e113ad..f91db1b2fc 100644
--- a/dimos/e2e_tests/test_dimos_cli_e2e.py
+++ b/dimos/e2e_tests/test_dimos_cli_e2e.py
@@ -19,6 +19,7 @@
@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.")
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.")
+@pytest.mark.e2e
def test_dimos_skills(lcm_spy, start_blueprint, human_input) -> None:
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/set_AgentSpec_register_skills/res")
lcm_spy.save_topic("/rpc/HumanInput/start/res")
@@ -26,7 +27,7 @@ def test_dimos_skills(lcm_spy, start_blueprint, human_input) -> None:
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/sum_numbers/req")
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/sum_numbers/res")
- start_blueprint("demo-skill")
+ start_blueprint("run", "demo-skill")
lcm_spy.wait_for_saved_topic("/rpc/DemoCalculatorSkill/set_AgentSpec_register_skills/res")
lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res")
diff --git a/dimos/e2e_tests/test_person_follow.py b/dimos/e2e_tests/test_person_follow.py
new file mode 100644
index 0000000000..709f4e4511
--- /dev/null
+++ b/dimos/e2e_tests/test_person_follow.py
@@ -0,0 +1,85 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable, Generator
+import os
+import threading
+import time
+
+import pytest
+
+from dimos.e2e_tests.dimos_cli_call import DimosCliCall
+from dimos.e2e_tests.lcm_spy import LcmSpy
+from dimos.simulation.mujoco.person_on_track import PersonTrackPublisher
+
+StartPersonTrack = Callable[[list[tuple[float, float]]], None]
+
+
+@pytest.fixture
+def start_person_track() -> Generator[StartPersonTrack, None, None]:
+ publisher: PersonTrackPublisher | None = None
+ stop_event = threading.Event()
+ thread: threading.Thread | None = None
+
+ def start(track: list[tuple[float, float]]) -> None:
+ nonlocal publisher, thread
+ publisher = PersonTrackPublisher(track)
+
+ def run_person_track() -> None:
+ while not stop_event.is_set():
+ publisher.tick()
+ time.sleep(1 / 60)
+
+ thread = threading.Thread(target=run_person_track, daemon=True)
+ thread.start()
+
+ yield start
+
+ stop_event.set()
+ if thread is not None:
+ thread.join(timeout=1.0)
+ if publisher is not None:
+ publisher.stop()
+
+
+@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.")
+@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.")
+@pytest.mark.mujoco
+def test_person_follow(
+ lcm_spy: LcmSpy,
+    start_blueprint: Callable[..., DimosCliCall],
+ human_input: Callable[[str], None],
+ start_person_track: StartPersonTrack,
+) -> None:
+ start_blueprint("--mujoco-start-pos", "-6.18 0.96", "run", "unitree-go2-agentic")
+
+ lcm_spy.save_topic("/rpc/HumanInput/start/res")
+ lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res", timeout=120.0)
+ lcm_spy.save_topic("/agent")
+ lcm_spy.wait_for_saved_topic_content("/agent", b"AIMessage", timeout=120.0)
+
+ time.sleep(5)
+
+ start_person_track(
+ [
+ (-2.60, 1.28),
+ (4.80, 0.21),
+ (4.14, -6.0),
+ (0.59, -3.79),
+ (-3.35, -0.51),
+ ]
+ )
+ human_input("follow the person in beige pants")
+
+ lcm_spy.wait_until_odom_position(4.2, -3, threshold=1.5)
diff --git a/dimos/e2e_tests/test_spatial_memory.py b/dimos/e2e_tests/test_spatial_memory.py
index 5029f46525..7c08800a6f 100644
--- a/dimos/e2e_tests/test_spatial_memory.py
+++ b/dimos/e2e_tests/test_spatial_memory.py
@@ -25,14 +25,14 @@
@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.")
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.")
-@pytest.mark.e2e
+@pytest.mark.mujoco
def test_spatial_memory_navigation(
lcm_spy: LcmSpy,
start_blueprint: Callable[[str], DimosCliCall],
human_input: Callable[[str], None],
follow_points: Callable[..., None],
) -> None:
- start_blueprint("unitree-go2-agentic")
+ start_blueprint("run", "unitree-go2-agentic")
lcm_spy.save_topic("/rpc/HumanInput/start/res")
lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res", timeout=120.0)
diff --git a/dimos/hardware/manipulators/README.md b/dimos/hardware/manipulators/README.md
index d4bb1cdba7..d3e54d4cb0 100644
--- a/dimos/hardware/manipulators/README.md
+++ b/dimos/hardware/manipulators/README.md
@@ -1,173 +1,163 @@
# Manipulator Drivers
-Component-based framework for integrating robotic manipulators into DIMOS.
+This module provides manipulator arm drivers using the **B-lite architecture**: Protocol-only with injectable backends.
-## Quick Start: Adding a New Manipulator
-
-Adding support for a new robot arm requires **two files**:
-1. **SDK Wrapper** (~200-500 lines) - Translates vendor SDK to standard interface
-2. **Driver** (~30-50 lines) - Assembles components and configuration
-
-## Directory Structure
+## Architecture Overview
```
-manipulators/
-├── base/ # Framework (don't modify)
-│ ├── sdk_interface.py # BaseManipulatorSDK abstract class
-│ ├── driver.py # BaseManipulatorDriver base class
-│ ├── spec.py # ManipulatorCapabilities dataclass
-│ └── components/ # Reusable standard components
-├── xarm/ # XArm implementation (reference)
-└── piper/ # Piper implementation (reference)
+┌─────────────────────────────────────────────────────────────┐
+│ Driver (Module) │
+│ - Owns threading (control loop, monitor loop) │
+│ - Publishes joint_state, robot_state │
+│ - Subscribes to joint_position_command, joint_velocity_cmd │
+│ - Exposes RPC methods (move_joint, enable_servos, etc.) │
+└─────────────────────┬───────────────────────────────────────┘
+ │ uses
+┌─────────────────────▼───────────────────────────────────────┐
+│ Backend (implements Protocol) │
+│ - Handles SDK communication │
+│ - Unit conversions (radians ↔ vendor units) │
+│ - Swappable: XArmBackend, PiperBackend, MockBackend │
+└─────────────────────────────────────────────────────────────┘
```
-## Hardware Requirements
+## Key Benefits
-Your manipulator **must** support:
+- **Testable**: Inject `MockBackend` for unit tests without hardware
+- **Flexible**: Each arm controls its own threading/timing
+- **Simple**: No ABC inheritance required - just implement the Protocol
+- **Type-safe**: Full type checking via `ManipulatorBackend` Protocol
-| Requirement | Description |
-|-------------|-------------|
-| Joint Position Feedback | Read current joint angles |
-| Joint Position Control | Command target joint positions |
-| Servo Enable/Disable | Enable and disable motor power |
-| Error Reporting | Report error codes/states |
-| Emergency Stop | Hardware or software e-stop |
+## Directory Structure
-**Optional:** velocity control, torque control, cartesian control, F/T sensor, gripper
+```
+manipulators/
+├── spec.py # ManipulatorBackend Protocol + shared types
+├── mock/
+│ └── backend.py # MockBackend for testing
+├── xarm/
+│ ├── backend.py # XArmBackend (SDK wrapper)
+│ ├── arm.py # XArm driver module
+│ └── blueprints.py # Pre-configured blueprints
+└── piper/
+ ├── backend.py # PiperBackend (SDK wrapper)
+ ├── arm.py # Piper driver module
+ └── blueprints.py # Pre-configured blueprints
+```
-## Step 1: Implement SDK Wrapper
+## Quick Start
-Create `your_arm/your_arm_wrapper.py` implementing `BaseManipulatorSDK`:
+### Using a Driver Directly
```python
-from dimos.hardware.manipulators.base.sdk_interface import BaseManipulatorSDK, ManipulatorInfo
-
-class YourArmSDKWrapper(BaseManipulatorSDK):
- def __init__(self):
- self._sdk = None
-
- def connect(self, config: dict) -> bool:
- self._sdk = YourNativeSDK(config['ip'])
- return self._sdk.connect()
-
- def get_joint_positions(self) -> list[float]:
- """Return positions in RADIANS."""
- degrees = self._sdk.get_angles()
- return [math.radians(d) for d in degrees]
+from dimos.hardware.manipulators.xarm import XArm
- def set_joint_positions(self, positions: list[float],
- velocity: float, acceleration: float) -> bool:
- return self._sdk.move_joints(positions, velocity)
-
- def enable_servos(self) -> bool:
- return self._sdk.motor_on()
-
- # ... implement remaining required methods (see sdk_interface.py)
+arm = XArm(ip="192.168.1.185", dof=6)
+arm.start()
+arm.enable_servos()
+arm.move_joint([0, 0, 0, 0, 0, 0])
+arm.stop()
```
-### Unit Conventions
-
-**All SDK wrappers must use these standard units:**
-
-| Quantity | Unit |
-|----------|------|
-| Joint positions | radians |
-| Joint velocities | rad/s |
-| Joint accelerations | rad/s^2 |
-| Joint torques | Nm |
-| Cartesian positions | meters |
-| Forces | N |
-
-## Step 2: Create Driver Assembly
-
-Create `your_arm/your_arm_driver.py`:
+### Using Blueprints
```python
-from dimos.hardware.manipulators.base.driver import BaseManipulatorDriver
-from dimos.hardware.manipulators.base.spec import ManipulatorCapabilities
-from dimos.hardware.manipulators.base.components import (
- StandardMotionComponent,
- StandardServoComponent,
- StandardStatusComponent,
-)
-from .your_arm_wrapper import YourArmSDKWrapper
-
-class YourArmDriver(BaseManipulatorDriver):
- def __init__(self, config: dict):
- sdk = YourArmSDKWrapper()
-
- capabilities = ManipulatorCapabilities(
- dof=6,
- has_gripper=False,
- has_force_torque=False,
- joint_limits_lower=[-3.14, -2.09, -3.14, -3.14, -3.14, -3.14],
- joint_limits_upper=[3.14, 2.09, 3.14, 3.14, 3.14, 3.14],
- max_joint_velocity=[2.0] * 6,
- max_joint_acceleration=[4.0] * 6,
- )
-
- components = [
- StandardMotionComponent(),
- StandardServoComponent(),
- StandardStatusComponent(),
- ]
+from dimos.hardware.manipulators.xarm.blueprints import xarm_trajectory
- super().__init__(sdk, components, config, capabilities)
+coordinator = xarm_trajectory.build()
+coordinator.loop()
```
-## Component API Decorator
-
-Use `@component_api` to expose methods as RPC endpoints:
+### Testing Without Hardware
```python
-from dimos.hardware.manipulators.base.components import component_api
+from dimos.hardware.manipulators.mock import MockBackend
+from dimos.hardware.manipulators.xarm import XArm
-class StandardMotionComponent:
- @component_api
- def move_joint(self, positions: list[float], velocity: float = 1.0):
- """Auto-exposed as driver.move_joint()"""
- ...
+arm = XArm(backend=MockBackend(dof=6))
+arm.start() # No hardware needed!
+arm.move_joint([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
```
-## Threading Architecture
+## Adding a New Arm
-The driver runs **2 threads**:
-1. **Control Loop (100Hz)** - Processes commands, reads joint state, publishes feedback
-2. **Monitor Loop (10Hz)** - Reads robot state, errors, optional sensors
+1. **Create the backend** (`backend.py`):
-```
-RPC Call → Command Queue → Control Loop → SDK → Hardware
- ↓
- SharedState → LCM Publisher
+```python
+class MyArmBackend: # No inheritance needed - just match the Protocol
+ def __init__(self, ip: str = "192.168.1.100", dof: int = 6) -> None:
+ self._ip = ip
+ self._dof = dof
+
+ def connect(self) -> bool: ...
+ def disconnect(self) -> None: ...
+ def read_joint_positions(self) -> list[float]: ...
+ def write_joint_positions(self, positions: list[float], velocity: float = 1.0) -> bool: ...
+ # ... implement other Protocol methods
```
-## Testing Your Driver
+2. **Create the driver** (`arm.py`):
```python
-driver = YourArmDriver({"ip": "192.168.1.100"})
-driver.start()
-driver.enable_servo()
-driver.move_joint([0, 0, 0, 0, 0, 0], velocity=0.5)
-state = driver.get_joint_state()
-driver.stop()
+from dimos.core import Module, ModuleConfig, In, Out, rpc
+from .backend import MyArmBackend
+
+class MyArm(Module[MyArmConfig]):
+ joint_state: Out[JointState]
+ robot_state: Out[RobotState]
+ joint_position_command: In[JointCommand]
+
+ def __init__(self, backend=None, **kwargs):
+ super().__init__(**kwargs)
+ self.backend = backend or MyArmBackend(
+ ip=self.config.ip,
+ dof=self.config.dof,
+ )
+ # ... setup control loops
```
-## Common Issues
+3. **Create blueprints** (`blueprints.py`) for common configurations.
-| Issue | Solution |
-|-------|----------|
-| Unit mismatch | Verify wrapper converts to radians/meters |
-| Commands ignored | Ensure servos are enabled before commanding |
-| Velocity not working | Some arms need mode switch via `set_control_mode()` |
+## ManipulatorBackend Protocol
-## Architecture Details
+All backends must implement these core methods:
-For complete architecture documentation including full SDK interface specification,
-component details, and testing strategies, see:
+| Category | Methods |
+|----------|---------|
+| Connection | `connect()`, `disconnect()`, `is_connected()` |
+| Info | `get_info()`, `get_dof()`, `get_limits()` |
+| State | `read_joint_positions()`, `read_joint_velocities()`, `read_joint_efforts()` |
+| Motion | `write_joint_positions()`, `write_joint_velocities()`, `write_stop()` |
+| Servo | `write_enable()`, `read_enabled()`, `write_clear_errors()` |
+| Mode | `set_control_mode()`, `get_control_mode()` |
-**[component_based_architecture.md](base/component_based_architecture.md)**
+Optional methods (return `None`/`False` if unsupported):
+- `read_cartesian_position()`, `write_cartesian_position()`
+- `read_gripper_position()`, `write_gripper_position()`
+- `read_force_torque()`
-## Reference Implementations
+## Unit Conventions
-- **XArm**: [xarm/xarm_wrapper.py](xarm/xarm_wrapper.py) - Full-featured wrapper
-- **Piper**: [piper/piper_wrapper.py](piper/piper_wrapper.py) - Shows velocity workaround
+All backends convert to/from SI units:
+
+| Quantity | Unit |
+|----------|------|
+| Angles | radians |
+| Angular velocity | rad/s |
+| Torque | Nm |
+| Position | meters |
+| Force | Newtons |
+
+## Available Blueprints
+
+### XArm
+- `xarm_servo` - Basic servo control (6-DOF)
+- `xarm5_servo`, `xarm7_servo` - 5/7-DOF variants
+- `xarm_trajectory` - Driver + trajectory controller
+- `xarm_cartesian` - Driver + cartesian controller
+
+### Piper
+- `piper_servo` - Basic servo control
+- `piper_servo_gripper` - With gripper support
+- `piper_trajectory` - Driver + trajectory controller
+- `piper_left`, `piper_right` - Dual arm configurations
diff --git a/dimos/hardware/manipulators/__init__.py b/dimos/hardware/manipulators/__init__.py
index a54a846afc..e4133dbb51 100644
--- a/dimos/hardware/manipulators/__init__.py
+++ b/dimos/hardware/manipulators/__init__.py
@@ -12,10 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Manipulator Hardware Drivers
+"""Manipulator drivers for robotic arms.
+
+Architecture: B-lite (Protocol-based backends with per-arm drivers)
+
+- spec.py: ManipulatorBackend Protocol and shared types
+- xarm/: XArm driver and backend
+- piper/: Piper driver and backend
+- mock/: Mock backend for testing
-Drivers for various robotic manipulator arms.
+Usage:
+ >>> from dimos.hardware.manipulators.xarm import XArm
+ >>> arm = XArm(ip="192.168.1.185")
+ >>> arm.start()
+ >>> arm.enable_servos()
+ >>> arm.move_joint([0, 0, 0, 0, 0, 0])
+
+Testing:
+ >>> from dimos.hardware.manipulators.xarm import XArm
+ >>> from dimos.hardware.manipulators.mock import MockBackend
+ >>> arm = XArm(backend=MockBackend())
+ >>> arm.start() # No hardware needed!
"""
-__all__ = []
+from dimos.hardware.manipulators.spec import (
+ ControlMode,
+ DriverStatus,
+ JointLimits,
+ ManipulatorBackend,
+ ManipulatorInfo,
+)
+
+__all__ = [
+ "ControlMode",
+ "DriverStatus",
+ "JointLimits",
+ "ManipulatorBackend",
+ "ManipulatorInfo",
+]
diff --git a/dimos/hardware/manipulators/base/__init__.py b/dimos/hardware/manipulators/base/__init__.py
deleted file mode 100644
index 3ed58d9819..0000000000
--- a/dimos/hardware/manipulators/base/__init__.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base framework for generalized manipulator drivers.
-
-This package provides the foundation for building manipulator drivers
-that work with any robotic arm (XArm, Piper, UR, Franka, etc.).
-"""
-
-from .components import StandardMotionComponent, StandardServoComponent, StandardStatusComponent
-from .driver import BaseManipulatorDriver, Command
-from .sdk_interface import BaseManipulatorSDK, ManipulatorInfo
-from .spec import ManipulatorCapabilities, ManipulatorDriverSpec, RobotState
-from .utils import SharedState
-
-__all__ = [
- # Driver
- "BaseManipulatorDriver",
- # SDK Interface
- "BaseManipulatorSDK",
- "Command",
- "ManipulatorCapabilities",
- # Spec
- "ManipulatorDriverSpec",
- "ManipulatorInfo",
- "RobotState",
- # Utils
- "SharedState",
- # Components
- "StandardMotionComponent",
- "StandardServoComponent",
- "StandardStatusComponent",
-]
diff --git a/dimos/hardware/manipulators/base/component_based_architecture.md b/dimos/hardware/manipulators/base/component_based_architecture.md
deleted file mode 100644
index 893ebf1276..0000000000
--- a/dimos/hardware/manipulators/base/component_based_architecture.md
+++ /dev/null
@@ -1,208 +0,0 @@
-# Component-Based Architecture for Manipulator Drivers
-
-## Overview
-
-This architecture provides maximum code reuse through standardized SDK wrappers and reusable components. Each new manipulator requires only an SDK wrapper (~200-500 lines) and a thin driver assembly (~30-50 lines).
-
-## Architecture Layers
-
-```
-┌─────────────────────────────────────────────────────┐
-│ RPC Interface │
-│ (Standardized across all arms) │
-└─────────────────────────────────────────────────────┘
- ▲
-┌─────────────────────────────────────────────────────┐
-│ Driver Instance (XArmDriver) │
-│ Extends DIMOS Module, assembles components │
-└─────────────────────────────────────────────────────┘
- ▲
-┌─────────────────────────────────────────────────────┐
-│ Standard Components │
-│ (Motion, Servo, Status) - reused everywhere │
-└─────────────────────────────────────────────────────┘
- ▲
-┌─────────────────────────────────────────────────────┐
-│ SDK Wrapper (XArmSDKWrapper) │
-│ Implements BaseManipulatorSDK interface │
-└─────────────────────────────────────────────────────┘
- ▲
-┌─────────────────────────────────────────────────────┐
-│ Native Vendor SDK (XArmAPI) │
-└─────────────────────────────────────────────────────┘
-```
-
-## Core Interfaces
-
-### BaseManipulatorSDK
-
-Abstract interface that all SDK wrappers must implement. See `sdk_interface.py` for full specification.
-
-**Required methods:** `connect()`, `disconnect()`, `is_connected()`, `get_joint_positions()`, `get_joint_velocities()`, `set_joint_positions()`, `enable_servos()`, `disable_servos()`, `emergency_stop()`, `get_error_code()`, `clear_errors()`, `get_info()`
-
-**Optional methods:** `get_force_torque()`, `get_gripper_position()`, `set_cartesian_position()`, etc.
-
-### ManipulatorCapabilities
-
-Dataclass defining arm properties: DOF, joint limits, velocity limits, feature flags.
-
-## Component System
-
-### @component_api Decorator
-
-Methods marked with `@component_api` are automatically exposed as RPC endpoints on the driver:
-
-```python
-from dimos.hardware.manipulators.base.components import component_api
-
-class StandardMotionComponent:
- @component_api
- def move_joint(self, positions: list[float], velocity: float = 1.0) -> dict:
- """Auto-exposed as driver.move_joint()"""
- ...
-```
-
-### Dependency Injection
-
-Components receive dependencies via setter methods, not constructor:
-
-```python
-class StandardMotionComponent:
- def __init__(self):
- self.sdk = None
- self.shared_state = None
- self.command_queue = None
- self.capabilities = None
-
- def set_sdk(self, sdk): self.sdk = sdk
- def set_shared_state(self, state): self.shared_state = state
- def set_command_queue(self, queue): self.command_queue = queue
- def set_capabilities(self, caps): self.capabilities = caps
- def initialize(self): pass # Called after all setters
-```
-
-### Standard Components
-
-| Component | Purpose | Key Methods |
-|-----------|---------|-------------|
-| `StandardMotionComponent` | Joint/cartesian motion | `move_joint()`, `move_joint_velocity()`, `get_joint_state()`, `stop_motion()` |
-| `StandardServoComponent` | Motor control | `enable_servo()`, `disable_servo()`, `emergency_stop()`, `set_control_mode()` |
-| `StandardStatusComponent` | Monitoring | `get_robot_state()`, `get_error_state()`, `get_health_metrics()` |
-
-## Threading Model
-
-The driver runs **2 threads**:
-
-1. **Control Loop (100Hz)** - Process commands, read joint state, publish feedback
-2. **Monitor Loop (10Hz)** - Read robot state, errors, optional sensors (F/T, gripper)
-
-```
-RPC Call → Command Queue → Control Loop → SDK → Hardware
- ↓
- SharedState (thread-safe)
- ↓
- LCM Publisher → External Systems
-```
-
-## DIMOS Module Integration
-
-The driver extends `Module` for pub/sub integration:
-
-```python
-class BaseManipulatorDriver(Module):
- def __init__(self, sdk, components, config, capabilities):
- super().__init__()
- self.shared_state = SharedState()
- self.command_queue = Queue(maxsize=10)
-
- # Inject dependencies into components
- for component in components:
- component.set_sdk(sdk)
- component.set_shared_state(self.shared_state)
- component.set_command_queue(self.command_queue)
- component.set_capabilities(capabilities)
- component.initialize()
-
- # Auto-expose @component_api methods
- self._auto_expose_component_apis()
-```
-
-## Adding a New Manipulator
-
-### Step 1: SDK Wrapper
-
-```python
-class YourArmSDKWrapper(BaseManipulatorSDK):
- def get_joint_positions(self) -> list[float]:
- degrees = self._sdk.get_angles()
- return [math.radians(d) for d in degrees] # Convert to radians
-
- def set_joint_positions(self, positions, velocity, acceleration) -> bool:
- return self._sdk.move_joints(positions, velocity)
-
- # ... implement remaining required methods
-```
-
-### Step 2: Driver Assembly
-
-```python
-class YourArmDriver(BaseManipulatorDriver):
- def __init__(self, config: dict):
- sdk = YourArmSDKWrapper()
- capabilities = ManipulatorCapabilities(
- dof=6,
- joint_limits_lower=[-3.14] * 6,
- joint_limits_upper=[3.14] * 6,
- )
- components = [
- StandardMotionComponent(),
- StandardServoComponent(),
- StandardStatusComponent(),
- ]
- super().__init__(sdk, components, config, capabilities)
-```
-
-## Unit Conventions
-
-All SDK wrappers must convert to standard units:
-
-| Quantity | Unit |
-|----------|------|
-| Positions | radians |
-| Velocities | rad/s |
-| Accelerations | rad/s^2 |
-| Torques | Nm |
-| Cartesian | meters |
-
-## Testing Strategy
-
-```python
-# Test SDK wrapper with mocked native SDK
-def test_wrapper_positions():
- mock = Mock()
- mock.get_angles.return_value = [0, 90, 180]
- wrapper = YourArmSDKWrapper()
- wrapper._sdk = mock
- assert wrapper.get_joint_positions() == [0, math.pi/2, math.pi]
-
-# Test component with mocked SDK wrapper
-def test_motion_component():
- mock_sdk = Mock(spec=BaseManipulatorSDK)
- component = StandardMotionComponent()
- component.set_sdk(mock_sdk)
- component.move_joint([0, 0, 0])
- # Verify command was queued
-```
-
-## Advantages
-
-- **Maximum reuse**: Components tested once, used by 100+ arms
-- **Consistent behavior**: All arms identical at RPC level
-- **Centralized fixes**: Fix once in component, all arms benefit
-- **Team scalability**: Developers work on wrappers independently
-- **Strong contracts**: SDK interface defines exact requirements
-
-## Reference Implementations
-
-- **XArm**: `xarm/xarm_wrapper.py` - Full-featured, converts degrees→radians
-- **Piper**: `piper/piper_wrapper.py` - Shows velocity integration workaround
diff --git a/dimos/hardware/manipulators/base/components/__init__.py b/dimos/hardware/manipulators/base/components/__init__.py
deleted file mode 100644
index b04f60f691..0000000000
--- a/dimos/hardware/manipulators/base/components/__init__.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Standard components for manipulator drivers."""
-
-from collections.abc import Callable
-from typing import Any, TypeVar
-
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-def component_api(fn: F) -> F:
- """Decorator to mark component methods that should be exposed as driver RPCs.
-
- Methods decorated with @component_api will be automatically discovered by the
- driver and exposed as @rpc methods on the driver instance. This allows external
- code to call these methods via the standard Module RPC system.
-
- Example:
- class MyComponent:
- @component_api
- def enable_servo(self):
- '''Enable servo motors.'''
- return self.sdk.enable_servos()
-
- # The driver will auto-generate:
- # @rpc
- # def enable_servo(self):
- # return component.enable_servo()
-
- # External code can then call:
- # driver.enable_servo()
- """
- fn.__component_api__ = True # type: ignore[attr-defined]
- return fn
-
-
-# Import components AFTER defining component_api to avoid circular imports
-from .motion import StandardMotionComponent
-from .servo import StandardServoComponent
-from .status import StandardStatusComponent
-
-__all__ = [
- "StandardMotionComponent",
- "StandardServoComponent",
- "StandardStatusComponent",
- "component_api",
-]
diff --git a/dimos/hardware/manipulators/base/components/motion.py b/dimos/hardware/manipulators/base/components/motion.py
deleted file mode 100644
index f3205acb01..0000000000
--- a/dimos/hardware/manipulators/base/components/motion.py
+++ /dev/null
@@ -1,591 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Standard motion control component for manipulator drivers."""
-
-import logging
-from queue import Queue
-import time
-from typing import Any
-
-from ..driver import Command
-from ..sdk_interface import BaseManipulatorSDK
-from ..spec import ManipulatorCapabilities
-from ..utils import SharedState, scale_velocities, validate_joint_limits, validate_velocity_limits
-from . import component_api
-
-
-class StandardMotionComponent:
- """Motion control component that works with any SDK wrapper.
-
- This component provides standard motion control methods that work
- consistently across all manipulator types. Methods decorated with @component_api
- are automatically exposed as RPC methods on the driver. It handles:
- - Joint position control
- - Joint velocity control
- - Joint effort/torque control (if supported)
- - Trajectory execution (if supported)
- - Motion safety validation
- """
-
- def __init__(
- self,
- sdk: BaseManipulatorSDK | None = None,
- shared_state: SharedState | None = None,
- command_queue: Queue[Any] | None = None,
- capabilities: ManipulatorCapabilities | None = None,
- ) -> None:
- """Initialize the motion component.
-
- Args:
- sdk: SDK wrapper instance (can be set later)
- shared_state: Shared state instance (can be set later)
- command_queue: Command queue (can be set later)
- capabilities: Manipulator capabilities (can be set later)
- """
- self.sdk = sdk
- self.shared_state = shared_state
- self.command_queue = command_queue
- self.capabilities = capabilities
- self.logger = logging.getLogger(self.__class__.__name__)
-
- # Motion limits
- self.velocity_scale = 1.0 # Global velocity scaling (0-1)
- self.acceleration_scale = 1.0 # Global acceleration scaling (0-1)
-
- # ============= Initialization Methods (called by BaseDriver) =============
-
- def set_sdk(self, sdk: BaseManipulatorSDK) -> None:
- """Set the SDK wrapper instance."""
- self.sdk = sdk
-
- def set_shared_state(self, shared_state: SharedState) -> None:
- """Set the shared state instance."""
- self.shared_state = shared_state
-
- def set_command_queue(self, command_queue: "Queue[Any]") -> None:
- """Set the command queue instance."""
- self.command_queue = command_queue
-
- def set_capabilities(self, capabilities: ManipulatorCapabilities) -> None:
- """Set the capabilities instance."""
- self.capabilities = capabilities
-
- def initialize(self) -> None:
- """Initialize the component after all resources are set."""
- self.logger.debug("Motion component initialized")
-
- # ============= Component API Methods =============
-
- @component_api
- def move_joint(
- self,
- positions: list[float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- validate: bool = True,
- ) -> dict[str, Any]:
- """Move joints to target positions.
-
- Args:
- positions: Target joint positions in radians
- velocity: Velocity scaling factor (0-1)
- acceleration: Acceleration scaling factor (0-1)
- wait: If True, block until motion completes
- validate: If True, validate against joint limits
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # Validate inputs
- if validate and self.capabilities:
- if len(positions) != self.capabilities.dof:
- return {
- "success": False,
- "error": f"Expected {self.capabilities.dof} positions, got {len(positions)}",
- }
-
- # Check joint limits
- if self.capabilities.joint_limits_lower and self.capabilities.joint_limits_upper:
- valid, error = validate_joint_limits(
- positions,
- self.capabilities.joint_limits_lower,
- self.capabilities.joint_limits_upper,
- )
- if not valid:
- return {"success": False, "error": error}
-
- # Apply global scaling
- velocity = velocity * self.velocity_scale
- acceleration = acceleration * self.acceleration_scale
-
- # Queue command for async execution
- if self.command_queue and not wait:
- command = Command(
- type="position",
- data={
- "positions": positions,
- "velocity": velocity,
- "acceleration": acceleration,
- "wait": False,
- },
- timestamp=time.time(),
- )
- self.command_queue.put(command)
- return {"success": True, "queued": True}
-
- # Execute directly (blocking or wait mode)
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- success = self.sdk.set_joint_positions(positions, velocity, acceleration, wait)
-
- if success and self.shared_state:
- self.shared_state.set_target_joints(positions=positions)
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in move_joint: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def move_joint_velocity(
- self, velocities: list[float], acceleration: float = 1.0, validate: bool = True
- ) -> dict[str, Any]:
- """Set joint velocities.
-
- Args:
- velocities: Target joint velocities in rad/s
- acceleration: Acceleration scaling factor (0-1)
- validate: If True, validate against velocity limits
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # Validate inputs
- if validate and self.capabilities:
- if len(velocities) != self.capabilities.dof:
- return {
- "success": False,
- "error": f"Expected {self.capabilities.dof} velocities, got {len(velocities)}",
- }
-
- # Check velocity limits
- if self.capabilities.max_joint_velocity:
- valid, _error = validate_velocity_limits(
- velocities, self.capabilities.max_joint_velocity, self.velocity_scale
- )
- if not valid:
- # Scale velocities to stay within limits
- velocities = scale_velocities(
- velocities, self.capabilities.max_joint_velocity, self.velocity_scale
- )
- self.logger.warning("Velocities scaled to stay within limits")
-
- # Queue command for async execution
- if self.command_queue:
- command = Command(
- type="velocity", data={"velocities": velocities}, timestamp=time.time()
- )
- self.command_queue.put(command)
- return {"success": True, "queued": True}
-
- # Execute directly
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- success = self.sdk.set_joint_velocities(velocities)
-
- if success and self.shared_state:
- self.shared_state.set_target_joints(velocities=velocities)
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in move_joint_velocity: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def move_joint_effort(self, efforts: list[float], validate: bool = True) -> dict[str, Any]:
- """Set joint efforts/torques.
-
- Args:
- efforts: Target joint efforts in Nm
- validate: If True, validate inputs
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # Check if effort control is supported
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- if not hasattr(self.sdk, "set_joint_efforts"):
- return {"success": False, "error": "Effort control not supported"}
-
- # Validate inputs
- if validate and self.capabilities:
- if len(efforts) != self.capabilities.dof:
- return {
- "success": False,
- "error": f"Expected {self.capabilities.dof} efforts, got {len(efforts)}",
- }
-
- # Queue command for async execution
- if self.command_queue:
- command = Command(type="effort", data={"efforts": efforts}, timestamp=time.time())
- self.command_queue.put(command)
- return {"success": True, "queued": True}
-
- # Execute directly
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- success = self.sdk.set_joint_efforts(efforts)
-
- if success and self.shared_state:
- self.shared_state.set_target_joints(efforts=efforts)
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in move_joint_effort: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def stop_motion(self) -> dict[str, Any]:
- """Stop all ongoing motion immediately.
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # Queue stop command with high priority
- if self.command_queue:
- command = Command(type="stop", data={}, timestamp=time.time())
- # Clear queue and add stop command
- while not self.command_queue.empty():
- try:
- self.command_queue.get_nowait()
- except:
- break
- self.command_queue.put(command)
-
- # Also execute directly for immediate stop
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- success = self.sdk.stop_motion()
-
- # Clear targets
- if self.shared_state:
- self.shared_state.set_target_joints(positions=None, velocities=None, efforts=None)
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in stop_motion: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_joint_state(self) -> dict[str, Any]:
- """Get current joint state.
-
- Returns:
- Dict with joint positions, velocities, efforts, and timestamp
- """
- try:
- if self.shared_state:
- # Get from shared state (updated by reader thread)
- positions, velocities, efforts = self.shared_state.get_joint_state()
- else:
- # Get directly from SDK
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- positions = self.sdk.get_joint_positions()
- velocities = self.sdk.get_joint_velocities()
- efforts = self.sdk.get_joint_efforts()
-
- return {
- "positions": positions,
- "velocities": velocities,
- "efforts": efforts,
- "timestamp": time.time(),
- "success": True,
- }
-
- except Exception as e:
- self.logger.error(f"Error in get_joint_state: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_joint_limits(self) -> dict[str, Any]:
- """Get joint position limits.
-
- Returns:
- Dict with lower and upper limits in radians
- """
- try:
- if self.capabilities:
- return {
- "lower": self.capabilities.joint_limits_lower,
- "upper": self.capabilities.joint_limits_upper,
- "success": True,
- }
- else:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- lower, upper = self.sdk.get_joint_limits()
- return {"lower": lower, "upper": upper, "success": True}
-
- except Exception as e:
- self.logger.error(f"Error in get_joint_limits: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_velocity_limits(self) -> dict[str, Any]:
- """Get joint velocity limits.
-
- Returns:
- Dict with maximum velocities in rad/s
- """
- try:
- if self.capabilities:
- return {"limits": self.capabilities.max_joint_velocity, "success": True}
- else:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- limits = self.sdk.get_velocity_limits()
- return {"limits": limits, "success": True}
-
- except Exception as e:
- self.logger.error(f"Error in get_velocity_limits: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def set_velocity_scale(self, scale: float) -> dict[str, Any]:
- """Set global velocity scaling factor.
-
- Args:
- scale: Velocity scale factor (0-1)
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if scale <= 0 or scale > 1:
- return {"success": False, "error": f"Invalid scale {scale}, must be in (0, 1]"}
-
- self.velocity_scale = scale
- return {"success": True, "scale": scale}
-
- except Exception as e:
- self.logger.error(f"Error in set_velocity_scale: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def set_acceleration_scale(self, scale: float) -> dict[str, Any]:
- """Set global acceleration scaling factor.
-
- Args:
- scale: Acceleration scale factor (0-1)
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if scale <= 0 or scale > 1:
- return {"success": False, "error": f"Invalid scale {scale}, must be in (0, 1]"}
-
- self.acceleration_scale = scale
- return {"success": True, "scale": scale}
-
- except Exception as e:
- self.logger.error(f"Error in set_acceleration_scale: {e}")
- return {"success": False, "error": str(e)}
-
- # ============= Cartesian Control (Optional) =============
-
- @component_api
- def move_cartesian(
- self,
- pose: dict[str, float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> dict[str, Any]:
- """Move end-effector to target pose.
-
- Args:
- pose: Target pose with keys: x, y, z (meters), roll, pitch, yaw (radians)
- velocity: Velocity scaling factor (0-1)
- acceleration: Acceleration scaling factor (0-1)
- wait: If True, block until motion completes
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # Check if Cartesian control is supported
- if not self.capabilities or not self.capabilities.has_cartesian_control:
- return {"success": False, "error": "Cartesian control not supported"}
-
- # Apply global scaling
- velocity = velocity * self.velocity_scale
- acceleration = acceleration * self.acceleration_scale
-
- # Queue command for async execution
- if self.command_queue and not wait:
- command = Command(
- type="cartesian",
- data={
- "pose": pose,
- "velocity": velocity,
- "acceleration": acceleration,
- "wait": False,
- },
- timestamp=time.time(),
- )
- self.command_queue.put(command)
- return {"success": True, "queued": True}
-
- # Execute directly
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- success = self.sdk.set_cartesian_position(pose, velocity, acceleration, wait)
-
- if success and self.shared_state:
- self.shared_state.target_cartesian_position = pose
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in move_cartesian: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_cartesian_state(self) -> dict[str, Any]:
- """Get current end-effector pose.
-
- Returns:
- Dict with pose (x, y, z, roll, pitch, yaw) and timestamp
- """
- try:
- # Check if Cartesian control is supported
- if not self.capabilities or not self.capabilities.has_cartesian_control:
- return {"success": False, "error": "Cartesian control not supported"}
-
- pose: dict[str, float] | None = None
- if self.shared_state and self.shared_state.cartesian_position:
- # Get from shared state
- pose = self.shared_state.cartesian_position
- else:
- # Get directly from SDK
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- pose = self.sdk.get_cartesian_position()
-
- if pose:
- return {"pose": pose, "timestamp": time.time(), "success": True}
- else:
- return {"success": False, "error": "Failed to get Cartesian state"}
-
- except Exception as e:
- self.logger.error(f"Error in get_cartesian_state: {e}")
- return {"success": False, "error": str(e)}
-
- # ============= Trajectory Execution (Optional) =============
-
- @component_api
- def execute_trajectory(
- self, trajectory: list[dict[str, Any]], wait: bool = True
- ) -> dict[str, Any]:
- """Execute a joint trajectory.
-
- Args:
- trajectory: List of waypoints, each with:
- - 'positions': list[float] in radians
- - 'velocities': Optional list[float] in rad/s
- - 'time': float seconds from start
- wait: If True, block until trajectory completes
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # Check if trajectory execution is supported
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- if not hasattr(self.sdk, "execute_trajectory"):
- return {"success": False, "error": "Trajectory execution not supported"}
-
- # Validate trajectory if capabilities available
- if self.capabilities:
- from ..utils import validate_trajectory
-
- # Only validate if all required capability fields are present
- jl_lower = self.capabilities.joint_limits_lower
- jl_upper = self.capabilities.joint_limits_upper
- max_vel = self.capabilities.max_joint_velocity
- max_acc = self.capabilities.max_joint_acceleration
-
- if (
- jl_lower is not None
- and jl_upper is not None
- and max_vel is not None
- and max_acc is not None
- ):
- valid, error = validate_trajectory(
- trajectory,
- jl_lower,
- jl_upper,
- max_vel,
- max_acc,
- )
- if not valid:
- return {"success": False, "error": error}
- else:
- self.logger.debug("Skipping trajectory validation; capabilities incomplete")
-
- # Execute trajectory
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- success = self.sdk.execute_trajectory(trajectory, wait)
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in execute_trajectory: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def stop_trajectory(self) -> dict[str, Any]:
- """Stop any executing trajectory.
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # Check if trajectory execution is supported
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
- if not hasattr(self.sdk, "stop_trajectory"):
- return {"success": False, "error": "Trajectory execution not supported"}
-
- success = self.sdk.stop_trajectory()
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in stop_trajectory: {e}")
- return {"success": False, "error": str(e)}
diff --git a/dimos/hardware/manipulators/base/components/servo.py b/dimos/hardware/manipulators/base/components/servo.py
deleted file mode 100644
index c773f10723..0000000000
--- a/dimos/hardware/manipulators/base/components/servo.py
+++ /dev/null
@@ -1,522 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Standard servo control component for manipulator drivers."""
-
-import logging
-import time
-from typing import Any
-
-from ..sdk_interface import BaseManipulatorSDK
-from ..spec import ManipulatorCapabilities
-from ..utils import SharedState
-from . import component_api
-
-
-class StandardServoComponent:
- """Servo control component that works with any SDK wrapper.
-
- This component provides standard servo/motor control methods that work
- consistently across all manipulator types. Methods decorated with @component_api
- are automatically exposed as RPC methods on the driver. It handles:
- - Servo enable/disable
- - Control mode switching
- - Emergency stop
- - Error recovery
- - Homing operations
- """
-
- def __init__(
- self,
- sdk: BaseManipulatorSDK | None = None,
- shared_state: SharedState | None = None,
- capabilities: ManipulatorCapabilities | None = None,
- ):
- """Initialize the servo component.
-
- Args:
- sdk: SDK wrapper instance (can be set later)
- shared_state: Shared state instance (can be set later)
- capabilities: Manipulator capabilities (can be set later)
- """
- self.sdk = sdk
- self.shared_state = shared_state
- self.capabilities = capabilities
- self.logger = logging.getLogger(self.__class__.__name__)
-
- # State tracking
- self.last_enable_time = 0.0
- self.last_disable_time = 0.0
-
- # ============= Initialization Methods (called by BaseDriver) =============
-
- def set_sdk(self, sdk: BaseManipulatorSDK) -> None:
- """Set the SDK wrapper instance."""
- self.sdk = sdk
-
- def set_shared_state(self, shared_state: SharedState) -> None:
- """Set the shared state instance."""
- self.shared_state = shared_state
-
- def set_capabilities(self, capabilities: ManipulatorCapabilities) -> None:
- """Set the capabilities instance."""
- self.capabilities = capabilities
-
- def initialize(self) -> None:
- """Initialize the component after all resources are set."""
- self.logger.debug("Servo component initialized")
-
- # ============= Component API Methods =============
-
- @component_api
- def enable_servo(self, check_errors: bool = True) -> dict[str, Any]:
- """Enable servo/motor control.
-
- Args:
- check_errors: If True, check for errors before enabling
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Check if already enabled
- if self.sdk.are_servos_enabled():
- return {"success": True, "message": "Servos already enabled"}
-
- # Check for errors if requested
- if check_errors:
- error_code = self.sdk.get_error_code()
- if error_code != 0:
- error_msg = self.sdk.get_error_message()
- return {
- "success": False,
- "error": f"Cannot enable servos with active error: {error_msg} (code: {error_code})",
- }
-
- # Enable servos
- success = self.sdk.enable_servos()
-
- if success:
- self.last_enable_time = time.time()
- if self.shared_state:
- self.shared_state.is_enabled = True
- self.logger.info("Servos enabled successfully")
- else:
- self.logger.error("Failed to enable servos")
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in enable_servo: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def disable_servo(self, stop_motion: bool = True) -> dict[str, Any]:
- """Disable servo/motor control.
-
- Args:
- stop_motion: If True, stop any ongoing motion first
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Check if already disabled
- if not self.sdk.are_servos_enabled():
- return {"success": True, "message": "Servos already disabled"}
-
- # Stop motion if requested
- if stop_motion:
- self.sdk.stop_motion()
- time.sleep(0.1) # Brief delay to ensure motion stopped
-
- # Disable servos
- success = self.sdk.disable_servos()
-
- if success:
- self.last_disable_time = time.time()
- if self.shared_state:
- self.shared_state.is_enabled = False
- self.logger.info("Servos disabled successfully")
- else:
- self.logger.error("Failed to disable servos")
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in disable_servo: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def toggle_servo(self) -> dict[str, Any]:
- """Toggle servo enable/disable state.
-
- Returns:
- Dict with 'success', 'enabled' state, and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- current_state = self.sdk.are_servos_enabled()
-
- if current_state:
- result = self.disable_servo()
- else:
- result = self.enable_servo()
-
- if result["success"]:
- result["enabled"] = not current_state
-
- return result
-
- except Exception as e:
- self.logger.error(f"Error in toggle_servo: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_servo_state(self) -> dict[str, Any]:
- """Get current servo state.
-
- Returns:
- Dict with servo state information
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- enabled = self.sdk.are_servos_enabled()
- robot_state = self.sdk.get_robot_state()
-
- return {
- "enabled": enabled,
- "mode": robot_state.get("mode", 0),
- "state": robot_state.get("state", 0),
- "is_moving": robot_state.get("is_moving", False),
- "last_enable_time": self.last_enable_time,
- "last_disable_time": self.last_disable_time,
- "success": True,
- }
-
- except Exception as e:
- self.logger.error(f"Error in get_servo_state: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def emergency_stop(self) -> dict[str, Any]:
- """Execute emergency stop.
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Execute e-stop
- success = self.sdk.emergency_stop()
-
- if success:
- # Update shared state
- if self.shared_state:
- self.shared_state.update_robot_state(state=3) # 3 = e-stop state
- self.shared_state.is_enabled = False
- self.shared_state.is_moving = False
-
- self.logger.warning("Emergency stop executed")
- else:
- self.logger.error("Failed to execute emergency stop")
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in emergency_stop: {e}")
- # Try to stop motion as fallback
- try:
- if self.sdk is not None:
- self.sdk.stop_motion()
- except:
- pass
- return {"success": False, "error": str(e)}
-
- @component_api
- def reset_emergency_stop(self) -> dict[str, Any]:
- """Reset from emergency stop state.
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Clear errors first
- self.sdk.clear_errors()
-
- # Re-enable servos
- success = self.sdk.enable_servos()
-
- if success:
- if self.shared_state:
- self.shared_state.update_robot_state(state=0) # 0 = idle
- self.shared_state.is_enabled = True
-
- self.logger.info("Emergency stop reset successfully")
- else:
- self.logger.error("Failed to reset emergency stop")
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in reset_emergency_stop: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def set_control_mode(self, mode: str) -> dict[str, Any]:
- """Set control mode.
-
- Args:
- mode: Control mode ('position', 'velocity', 'torque', 'impedance')
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Validate mode
- valid_modes = ["position", "velocity", "torque", "impedance"]
- if mode not in valid_modes:
- return {
- "success": False,
- "error": f"Invalid mode '{mode}'. Valid modes: {valid_modes}",
- }
-
- # Check if mode is supported
- if mode == "impedance" and self.capabilities:
- if not self.capabilities.has_impedance_control:
- return {"success": False, "error": "Impedance control not supported"}
-
- # Set control mode
- success = self.sdk.set_control_mode(mode)
-
- if success:
- # Map mode string to integer
- mode_map = {"position": 0, "velocity": 1, "torque": 2, "impedance": 3}
- if self.shared_state:
- self.shared_state.update_robot_state(mode=mode_map.get(mode, 0))
-
- self.logger.info(f"Control mode set to '{mode}'")
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in set_control_mode: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_control_mode(self) -> dict[str, Any]:
- """Get current control mode.
-
- Returns:
- Dict with current mode and success status
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- mode = self.sdk.get_control_mode()
-
- if mode:
- return {"mode": mode, "success": True}
- else:
- # Try to get from robot state
- robot_state = self.sdk.get_robot_state()
- mode_int = robot_state.get("mode", 0)
-
- # Map integer to string
- mode_map = {0: "position", 1: "velocity", 2: "torque", 3: "impedance"}
- mode_str = mode_map.get(mode_int, "unknown")
-
- return {"mode": mode_str, "success": True}
-
- except Exception as e:
- self.logger.error(f"Error in get_control_mode: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def clear_errors(self) -> dict[str, Any]:
- """Clear any error states.
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Clear errors via SDK
- success = self.sdk.clear_errors()
-
- if success:
- # Update shared state
- if self.shared_state:
- self.shared_state.clear_errors()
-
- self.logger.info("Errors cleared successfully")
- else:
- self.logger.error("Failed to clear errors")
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in clear_errors: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def reset_fault(self) -> dict[str, Any]:
- """Reset from fault state.
-
- This typically involves:
- 1. Clearing errors
- 2. Disabling servos
- 3. Brief delay
- 4. Re-enabling servos
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- self.logger.info("Resetting fault state...")
-
- # Step 1: Clear errors
- if not self.sdk.clear_errors():
- return {"success": False, "error": "Failed to clear errors"}
-
- # Step 2: Disable servos if enabled
- if self.sdk.are_servos_enabled():
- if not self.sdk.disable_servos():
- return {"success": False, "error": "Failed to disable servos"}
-
- # Step 3: Brief delay
- time.sleep(0.5)
-
- # Step 4: Re-enable servos
- if not self.sdk.enable_servos():
- return {"success": False, "error": "Failed to re-enable servos"}
-
- # Update shared state
- if self.shared_state:
- self.shared_state.update_robot_state(
- state=0, # idle
- error_code=0,
- error_message="",
- )
- self.shared_state.is_enabled = True
-
- self.logger.info("Fault reset successfully")
- return {"success": True}
-
- except Exception as e:
- self.logger.error(f"Error in reset_fault: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def home_robot(self, position: list[float] | None = None) -> dict[str, Any]:
- """Move robot to home position.
-
- Args:
- position: Optional home position in radians.
- If None, uses zero position or configured home.
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Determine home position
- if position is None:
- # Use configured home or zero position
- if self.capabilities:
- position = [0.0] * self.capabilities.dof
- else:
- # Get current DOF from joint state
- current = self.sdk.get_joint_positions()
- position = [0.0] * len(current)
-
- # Enable servos if needed
- if not self.sdk.are_servos_enabled():
- if not self.sdk.enable_servos():
- return {"success": False, "error": "Failed to enable servos"}
-
- # Move to home position
- success = self.sdk.set_joint_positions(
- position,
- velocity=0.3, # Slower speed for homing
- acceleration=0.3,
- wait=True, # Wait for completion
- )
-
- if success:
- if self.shared_state:
- self.shared_state.is_homed = True
- self.logger.info("Robot homed successfully")
-
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in home_robot: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def brake_release(self) -> dict[str, Any]:
- """Release motor brakes (if applicable).
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # This is typically the same as enabling servos
- return self.enable_servo()
-
- except Exception as e:
- self.logger.error(f"Error in brake_release: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def brake_engage(self) -> dict[str, Any]:
- """Engage motor brakes (if applicable).
-
- Returns:
- Dict with 'success' and optional 'error' keys
- """
- try:
- # This is typically the same as disabling servos
- return self.disable_servo()
-
- except Exception as e:
- self.logger.error(f"Error in brake_engage: {e}")
- return {"success": False, "error": str(e)}
diff --git a/dimos/hardware/manipulators/base/components/status.py b/dimos/hardware/manipulators/base/components/status.py
deleted file mode 100644
index b20897ac65..0000000000
--- a/dimos/hardware/manipulators/base/components/status.py
+++ /dev/null
@@ -1,595 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Standard status monitoring component for manipulator drivers."""
-
-from collections import deque
-from dataclasses import dataclass
-import logging
-import time
-from typing import Any
-
-from ..sdk_interface import BaseManipulatorSDK
-from ..spec import ManipulatorCapabilities
-from ..utils import SharedState
-from . import component_api
-
-
-@dataclass
-class HealthMetrics:
- """Health metrics for monitoring."""
-
- update_rate: float = 0.0 # Hz
- command_rate: float = 0.0 # Hz
- error_rate: float = 0.0 # errors/minute
- uptime: float = 0.0 # seconds
- total_errors: int = 0
- total_commands: int = 0
- total_updates: int = 0
-
-
-class StandardStatusComponent:
- """Status monitoring component that works with any SDK wrapper.
-
- This component provides standard status monitoring methods that work
- consistently across all manipulator types. Methods decorated with @component_api
- are automatically exposed as RPC methods on the driver. It handles:
- - Robot state queries
- - Error monitoring
- - Health metrics
- - System information
- - Force/torque monitoring (if supported)
- - Temperature monitoring (if supported)
- """
-
- def __init__(
- self,
- sdk: BaseManipulatorSDK | None = None,
- shared_state: SharedState | None = None,
- capabilities: ManipulatorCapabilities | None = None,
- ):
- """Initialize the status component.
-
- Args:
- sdk: SDK wrapper instance (can be set later)
- shared_state: Shared state instance (can be set later)
- capabilities: Manipulator capabilities (can be set later)
- """
- self.sdk = sdk
- self.shared_state = shared_state
- self.capabilities = capabilities
- self.logger = logging.getLogger(self.__class__.__name__)
-
- # Health monitoring
- self.start_time = time.time()
- self.health_metrics = HealthMetrics()
-
- # Rate calculation
- self.update_timestamps: deque[float] = deque(maxlen=100)
- self.command_timestamps: deque[float] = deque(maxlen=100)
- self.error_timestamps: deque[float] = deque(maxlen=100)
-
- # Error history
- self.error_history: deque[dict[str, Any]] = deque(maxlen=50)
-
- # ============= Initialization Methods (called by BaseDriver) =============
-
- def set_sdk(self, sdk: BaseManipulatorSDK) -> None:
- """Set the SDK wrapper instance."""
- self.sdk = sdk
-
- def set_shared_state(self, shared_state: SharedState) -> None:
- """Set the shared state instance."""
- self.shared_state = shared_state
-
- def set_capabilities(self, capabilities: ManipulatorCapabilities) -> None:
- """Set the capabilities instance."""
- self.capabilities = capabilities
-
- def initialize(self) -> None:
- """Initialize the component after all resources are set."""
- self.start_time = time.time()
- self.logger.debug("Status component initialized")
-
- def publish_state(self) -> None:
- """Called periodically to update metrics (by publisher thread)."""
- current_time = time.time()
- self.update_timestamps.append(current_time)
- self._update_health_metrics()
-
- # ============= Component API Methods =============
-
- @component_api
- def get_robot_state(self) -> dict[str, Any]:
- """Get comprehensive robot state.
-
- Returns:
- Dict with complete state information
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- current_time = time.time()
-
- # Get state from SDK
- robot_state = self.sdk.get_robot_state()
-
- # Get additional info
- error_msg = (
- self.sdk.get_error_message() if robot_state.get("error_code", 0) != 0 else ""
- )
-
- # Map state integer to string
- state_map = {0: "idle", 1: "moving", 2: "error", 3: "emergency_stop"}
- state_str = state_map.get(robot_state.get("state", 0), "unknown")
-
- # Map mode integer to string
- mode_map = {0: "position", 1: "velocity", 2: "torque", 3: "impedance"}
- mode_str = mode_map.get(robot_state.get("mode", 0), "unknown")
-
- result = {
- "state": state_str,
- "state_code": robot_state.get("state", 0),
- "mode": mode_str,
- "mode_code": robot_state.get("mode", 0),
- "error_code": robot_state.get("error_code", 0),
- "error_message": error_msg,
- "is_moving": robot_state.get("is_moving", False),
- "is_connected": self.sdk.is_connected(),
- "is_enabled": self.sdk.are_servos_enabled(),
- "timestamp": current_time,
- "success": True,
- }
-
- # Add shared state info if available
- if self.shared_state:
- result["is_homed"] = self.shared_state.is_homed
- result["last_update"] = self.shared_state.last_state_update
- result["last_command"] = self.shared_state.last_command_sent
-
- return result
-
- except Exception as e:
- self.logger.error(f"Error in get_robot_state: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_system_info(self) -> dict[str, Any]:
- """Get system information.
-
- Returns:
- Dict with system information
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- # Get manipulator info
- info = self.sdk.get_info()
-
- result = {
- "vendor": info.vendor,
- "model": info.model,
- "dof": info.dof,
- "firmware_version": info.firmware_version,
- "serial_number": info.serial_number,
- "success": True,
- }
-
- # Add capabilities if available
- if self.capabilities:
- result["capabilities"] = {
- "dof": self.capabilities.dof,
- "has_gripper": self.capabilities.has_gripper,
- "has_force_torque": self.capabilities.has_force_torque,
- "has_impedance_control": self.capabilities.has_impedance_control,
- "has_cartesian_control": self.capabilities.has_cartesian_control,
- "payload_mass": self.capabilities.payload_mass,
- "reach": self.capabilities.reach,
- }
-
- return result
-
- except Exception as e:
- self.logger.error(f"Error in get_system_info: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_capabilities(self) -> dict[str, Any]:
- """Get manipulator capabilities.
-
- Returns:
- Dict with capability information
- """
- try:
- if not self.capabilities:
- return {"success": False, "error": "Capabilities not available"}
-
- return {
- "dof": self.capabilities.dof,
- "has_gripper": self.capabilities.has_gripper,
- "has_force_torque": self.capabilities.has_force_torque,
- "has_impedance_control": self.capabilities.has_impedance_control,
- "has_cartesian_control": self.capabilities.has_cartesian_control,
- "joint_limits_lower": self.capabilities.joint_limits_lower,
- "joint_limits_upper": self.capabilities.joint_limits_upper,
- "max_joint_velocity": self.capabilities.max_joint_velocity,
- "max_joint_acceleration": self.capabilities.max_joint_acceleration,
- "payload_mass": self.capabilities.payload_mass,
- "reach": self.capabilities.reach,
- "success": True,
- }
-
- except Exception as e:
- self.logger.error(f"Error in get_capabilities: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_error_state(self) -> dict[str, Any]:
- """Get detailed error state.
-
- Returns:
- Dict with error information
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- error_code = self.sdk.get_error_code()
- error_msg = self.sdk.get_error_message()
-
- result = {
- "has_error": error_code != 0,
- "error_code": error_code,
- "error_message": error_msg,
- "error_history": list(self.error_history),
- "total_errors": self.health_metrics.total_errors,
- "success": True,
- }
-
- # Add last error time from shared state
- if self.shared_state:
- result["last_error_time"] = self.shared_state.last_error_time
-
- return result
-
- except Exception as e:
- self.logger.error(f"Error in get_error_state: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_health_metrics(self) -> dict[str, Any]:
- """Get health metrics.
-
- Returns:
- Dict with health metrics
- """
- try:
- self._update_health_metrics()
-
- return {
- "uptime": self.health_metrics.uptime,
- "update_rate": self.health_metrics.update_rate,
- "command_rate": self.health_metrics.command_rate,
- "error_rate": self.health_metrics.error_rate,
- "total_updates": self.health_metrics.total_updates,
- "total_commands": self.health_metrics.total_commands,
- "total_errors": self.health_metrics.total_errors,
- "is_healthy": self._is_healthy(),
- "timestamp": time.time(),
- "success": True,
- }
-
- except Exception as e:
- self.logger.error(f"Error in get_health_metrics: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_statistics(self) -> dict[str, Any]:
- """Get operation statistics.
-
- Returns:
- Dict with statistics
- """
- try:
- stats = {}
-
- # Get stats from shared state
- if self.shared_state:
- stats.update(self.shared_state.get_statistics())
-
- # Add component stats
- stats["uptime"] = time.time() - self.start_time
- stats["health_metrics"] = {
- "update_rate": self.health_metrics.update_rate,
- "command_rate": self.health_metrics.command_rate,
- "error_rate": self.health_metrics.error_rate,
- }
-
- stats["success"] = True
- return stats
-
- except Exception as e:
- self.logger.error(f"Error in get_statistics: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def check_connection(self) -> dict[str, Any]:
- """Check connection status.
-
- Returns:
- Dict with connection status
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- connected = self.sdk.is_connected()
-
- result: dict[str, Any] = {
- "connected": connected,
- "timestamp": time.time(),
- "success": True,
- }
-
- # Try to get more info if connected
- if connected:
- try:
- # Try a simple query to verify connection
- self.sdk.get_error_code()
- result["verified"] = True
- except:
- result["verified"] = False
- result["message"] = "Connected but cannot communicate"
-
- return result
-
- except Exception as e:
- self.logger.error(f"Error in check_connection: {e}")
- return {"success": False, "error": str(e)}
-
- # ============= Force/Torque Monitoring (Optional) =============
-
- @component_api
- def get_force_torque(self) -> dict[str, Any]:
- """Get force/torque sensor data.
-
- Returns:
- Dict with F/T data if available
- """
- try:
- # Check if F/T is supported
- if not self.capabilities or not self.capabilities.has_force_torque:
- return {"success": False, "error": "Force/torque sensor not available"}
-
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- ft_data = self.sdk.get_force_torque()
-
- if ft_data:
- return {
- "force": ft_data[:3] if len(ft_data) >= 3 else None, # [fx, fy, fz]
- "torque": ft_data[3:6] if len(ft_data) >= 6 else None, # [tx, ty, tz]
- "data": ft_data,
- "timestamp": time.time(),
- "success": True,
- }
- else:
- return {"success": False, "error": "Failed to read F/T sensor"}
-
- except Exception as e:
- self.logger.error(f"Error in get_force_torque: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def zero_force_torque(self) -> dict[str, Any]:
- """Zero the force/torque sensor.
-
- Returns:
- Dict with success status
- """
- try:
- # Check if F/T is supported
- if not self.capabilities or not self.capabilities.has_force_torque:
- return {"success": False, "error": "Force/torque sensor not available"}
-
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- success = self.sdk.zero_force_torque()
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in zero_force_torque: {e}")
- return {"success": False, "error": str(e)}
-
- # ============= I/O Monitoring (Optional) =============
-
- @component_api
- def get_digital_inputs(self) -> dict[str, Any]:
- """Get digital input states.
-
- Returns:
- Dict with digital input states
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- inputs = self.sdk.get_digital_inputs()
-
- if inputs is not None:
- return {"inputs": inputs, "timestamp": time.time(), "success": True}
- else:
- return {"success": False, "error": "Digital inputs not available"}
-
- except Exception as e:
- self.logger.error(f"Error in get_digital_inputs: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def set_digital_outputs(self, outputs: dict[str, bool]) -> dict[str, Any]:
- """Set digital output states.
-
- Args:
- outputs: Dict of output_id: bool
-
- Returns:
- Dict with success status
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- success = self.sdk.set_digital_outputs(outputs)
- return {"success": success}
-
- except Exception as e:
- self.logger.error(f"Error in set_digital_outputs: {e}")
- return {"success": False, "error": str(e)}
-
- @component_api
- def get_analog_inputs(self) -> dict[str, Any]:
- """Get analog input values.
-
- Returns:
- Dict with analog input values
- """
- try:
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- inputs = self.sdk.get_analog_inputs()
-
- if inputs is not None:
- return {"inputs": inputs, "timestamp": time.time(), "success": True}
- else:
- return {"success": False, "error": "Analog inputs not available"}
-
- except Exception as e:
- self.logger.error(f"Error in get_analog_inputs: {e}")
- return {"success": False, "error": str(e)}
-
- # ============= Gripper Status (Optional) =============
-
- @component_api
- def get_gripper_state(self) -> dict[str, Any]:
- """Get gripper state.
-
- Returns:
- Dict with gripper state
- """
- try:
- # Check if gripper is supported
- if not self.capabilities or not self.capabilities.has_gripper:
- return {"success": False, "error": "Gripper not available"}
-
- if self.sdk is None:
- return {"success": False, "error": "SDK not configured"}
-
- position = self.sdk.get_gripper_position()
-
- if position is not None:
- result: dict[str, Any] = {
- "position": position, # meters
- "timestamp": time.time(),
- "success": True,
- }
-
- # Add from shared state if available
- if self.shared_state and self.shared_state.gripper_force is not None:
- result["force"] = self.shared_state.gripper_force
-
- return result
- else:
- return {"success": False, "error": "Failed to get gripper state"}
-
- except Exception as e:
- self.logger.error(f"Error in get_gripper_state: {e}")
- return {"success": False, "error": str(e)}
-
- # ============= Helper Methods =============
-
- def _update_health_metrics(self) -> None:
- """Update health metrics based on recent data."""
- current_time = time.time()
-
- # Update uptime
- self.health_metrics.uptime = current_time - self.start_time
-
- # Calculate update rate
- if len(self.update_timestamps) > 1:
- time_span = self.update_timestamps[-1] - self.update_timestamps[0]
- if time_span > 0:
- self.health_metrics.update_rate = len(self.update_timestamps) / time_span
-
- # Calculate command rate
- if len(self.command_timestamps) > 1:
- time_span = self.command_timestamps[-1] - self.command_timestamps[0]
- if time_span > 0:
- self.health_metrics.command_rate = len(self.command_timestamps) / time_span
-
- # Calculate error rate (errors per minute)
- recent_errors = [t for t in self.error_timestamps if current_time - t < 60]
- self.health_metrics.error_rate = len(recent_errors)
-
- # Update totals from shared state
- if self.shared_state:
- stats = self.shared_state.get_statistics()
- self.health_metrics.total_updates = stats.get("state_read_count", 0)
- self.health_metrics.total_commands = stats.get("command_sent_count", 0)
- self.health_metrics.total_errors = stats.get("error_count", 0)
-
- def _is_healthy(self) -> bool:
- """Check if system is healthy based on metrics."""
- # Check update rate (should be > 10 Hz)
- if self.health_metrics.update_rate < 10:
- return False
-
- # Check error rate (should be < 10 per minute)
- if self.health_metrics.error_rate > 10:
- return False
-
- # Check SDK is configured
- if self.sdk is None:
- return False
-
- # Check connection
- if not self.sdk.is_connected():
- return False
-
- # Check for persistent errors
- if self.sdk.get_error_code() != 0:
- return False
-
- return True
-
- def record_error(self, error_code: int, error_msg: str) -> None:
- """Record an error occurrence.
-
- Args:
- error_code: Error code
- error_msg: Error message
- """
- current_time = time.time()
- self.error_timestamps.append(current_time)
- self.error_history.append(
- {"code": error_code, "message": error_msg, "timestamp": current_time}
- )
-
- def record_command(self) -> None:
- """Record a command occurrence."""
- self.command_timestamps.append(time.time())
diff --git a/dimos/hardware/manipulators/base/driver.py b/dimos/hardware/manipulators/base/driver.py
deleted file mode 100644
index be68be5a23..0000000000
--- a/dimos/hardware/manipulators/base/driver.py
+++ /dev/null
@@ -1,637 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base manipulator driver with threading and component management."""
-
-from dataclasses import dataclass
-import logging
-from queue import Empty, Queue
-from threading import Event, Thread
-import time
-from typing import Any
-
-from dimos.core import In, Module, Out, rpc
-from dimos.msgs.geometry_msgs import WrenchStamped
-from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState
-
-from .sdk_interface import BaseManipulatorSDK
-from .spec import ManipulatorCapabilities
-from .utils import SharedState
-
-
-@dataclass
-class Command:
- """Command to be sent to the manipulator."""
-
- type: str # 'position', 'velocity', 'effort', 'cartesian', etc.
- data: Any
- timestamp: float = 0.0
-
-
-class BaseManipulatorDriver(Module):
- """Base driver providing threading and component management.
-
- This class handles:
- - Thread management (state reader, command sender, state publisher)
- - Component registration and lifecycle
- - RPC method registration
- - Shared state management
- - Error handling and recovery
- - Pub/Sub with LCM transport for real-time control
- """
-
- # Input topics (commands from controllers - initialized by Module)
- joint_position_command: In[JointCommand] = None # type: ignore[assignment]
- joint_velocity_command: In[JointCommand] = None # type: ignore[assignment]
-
- # Output topics (state publishing - initialized by Module)
- joint_state: Out[JointState] = None # type: ignore[assignment]
- robot_state: Out[RobotState] = None # type: ignore[assignment]
- ft_sensor: Out[WrenchStamped] = None # type: ignore[assignment]
-
- def __init__(
- self,
- sdk: BaseManipulatorSDK,
- components: list[Any],
- config: dict[str, Any],
- name: str | None = None,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- """Initialize the base manipulator driver.
-
- Args:
- sdk: SDK wrapper instance
- components: List of component instances
- config: Configuration dictionary
- name: Optional driver name for logging
- *args, **kwargs: Additional arguments for Module
- """
- # Initialize Module parent class
- super().__init__(*args, **kwargs)
-
- self.sdk = sdk
- self.components = components
- self.config: Any = config # Config dict accessed as object
- self.name = name or self.__class__.__name__
-
- # Logging
- self.logger = logging.getLogger(self.name)
-
- # Shared state
- self.shared_state = SharedState()
-
- # Threading
- self.stop_event = Event()
- self.threads: list[Thread] = []
- self.command_queue: Queue[Any] = Queue(maxsize=10)
-
- # RPC registry
- self.rpc_methods: dict[str, Any] = {}
- self._exposed_component_apis: set[str] = set() # Track auto-exposed method names
-
- # Capabilities
- self.capabilities = self._get_capabilities()
-
- # Rate control
- self.control_rate = config.get("control_rate", 100) # Hz - control loop + joint feedback
- self.monitor_rate = config.get("monitor_rate", 10) # Hz - robot state monitoring
-
- # Pre-allocate reusable objects (optimization: avoid per-cycle allocation)
- # Note: _joint_names is populated after _get_capabilities() sets self.capabilities
- self._joint_names: list[str] = [f"joint{i + 1}" for i in range(self.capabilities.dof)]
-
- # Initialize components with shared resources
- self._initialize_components()
-
- # Auto-expose component API methods as RPCs on the driver
- self._auto_expose_component_apis()
-
- # Connect to hardware
- self._connect()
-
- def _get_capabilities(self) -> ManipulatorCapabilities:
- """Get manipulator capabilities from config or SDK.
-
- Returns:
- ManipulatorCapabilities instance
- """
- # Try to get from SDK info
- info = self.sdk.get_info()
-
- # Get joint limits
- lower_limits, upper_limits = self.sdk.get_joint_limits()
- velocity_limits = self.sdk.get_velocity_limits()
- acceleration_limits = self.sdk.get_acceleration_limits()
-
- return ManipulatorCapabilities(
- dof=info.dof,
- has_gripper=self.config.get("has_gripper", False),
- has_force_torque=self.config.get("has_force_torque", False),
- has_impedance_control=self.config.get("has_impedance_control", False),
- has_cartesian_control=self.config.get("has_cartesian_control", False),
- max_joint_velocity=velocity_limits,
- max_joint_acceleration=acceleration_limits,
- joint_limits_lower=lower_limits,
- joint_limits_upper=upper_limits,
- payload_mass=self.config.get("payload_mass", 0.0),
- reach=self.config.get("reach", 0.0),
- )
-
- def _initialize_components(self) -> None:
- """Initialize components with shared resources."""
- for component in self.components:
- # Provide access to shared state
- if hasattr(component, "set_shared_state"):
- component.set_shared_state(self.shared_state)
-
- # Provide access to SDK
- if hasattr(component, "set_sdk"):
- component.set_sdk(self.sdk)
-
- # Provide access to command queue
- if hasattr(component, "set_command_queue"):
- component.set_command_queue(self.command_queue)
-
- # Provide access to capabilities
- if hasattr(component, "set_capabilities"):
- component.set_capabilities(self.capabilities)
-
- # Initialize component
- if hasattr(component, "initialize"):
- component.initialize()
-
- def _auto_expose_component_apis(self) -> None:
- """Auto-expose @component_api methods from components as RPC methods on the driver.
-
- This scans all components for methods decorated with @component_api and creates
- corresponding @rpc wrapper methods on the driver instance. This allows external
- code to call these methods via the standard Module RPC system.
-
- Example:
- # Component defines:
- @component_api
- def enable_servo(self): ...
-
- # Driver auto-generates an RPC wrapper, so external code can call:
- driver.enable_servo()
-
- # And the method is discoverable via:
- driver.rpcs # Lists 'enable_servo' among available RPCs
- """
- for component in self.components:
- for method_name in dir(component):
- if method_name.startswith("_"):
- continue
-
- method = getattr(component, method_name, None)
- if not callable(method) or not getattr(method, "__component_api__", False):
- continue
-
- # Skip if driver already has a non-wrapper method with this name
- existing = getattr(self, method_name, None)
- if existing is not None and not getattr(
- existing, "__component_api_wrapper__", False
- ):
- self.logger.warning(
- f"Driver already has method '{method_name}', skipping component API"
- )
- continue
-
- # Create RPC wrapper - use factory to properly capture method reference
- wrapper = self._create_component_api_wrapper(method)
-
- # Attach to driver instance
- setattr(self, method_name, wrapper)
-
- # Store in rpc_methods dict for backward compatibility
- self.rpc_methods[method_name] = wrapper
-
- # Track exposed method name for cleanup
- self._exposed_component_apis.add(method_name)
-
- self.logger.debug(f"Exposed component API as RPC: {method_name}")
-
- def _create_component_api_wrapper(self, component_method: Any) -> Any:
- """Create an RPC wrapper for a component API method.
-
- Args:
- component_method: The component method to wrap
-
- Returns:
- RPC-decorated wrapper function
- """
- import functools
-
- @rpc
- @functools.wraps(component_method)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- return component_method(*args, **kwargs)
-
- wrapper.__component_api_wrapper__ = True # type: ignore[attr-defined]
- return wrapper
-
- def _connect(self) -> None:
- """Connect to the manipulator hardware."""
- self.logger.info(f"Connecting to {self.name}...")
-
- # Connect via SDK
- if not self.sdk.connect(self.config):
- raise RuntimeError(f"Failed to connect to {self.name}")
-
- self.shared_state.is_connected = True
- self.logger.info(f"Successfully connected to {self.name}")
-
- # Get initial state
- self._update_joint_state()
- self._update_robot_state()
-
- def _update_joint_state(self) -> None:
- """Update joint state from hardware (high frequency - 100Hz).
-
- Reads joint positions, velocities, efforts and publishes to LCM immediately.
- """
- try:
- # Get joint state feedback
- positions = self.sdk.get_joint_positions()
- velocities = self.sdk.get_joint_velocities()
- efforts = self.sdk.get_joint_efforts()
-
- self.shared_state.update_joint_state(
- positions=positions, velocities=velocities, efforts=efforts
- )
-
- # Publish joint state immediately at control rate
- if self.joint_state and hasattr(self.joint_state, "publish"):
- joint_state_msg = JointState(
- ts=time.time(),
- frame_id="joint-state",
- name=self._joint_names, # Pre-allocated list (optimization)
- position=positions or [0.0] * self.capabilities.dof,
- velocity=velocities or [0.0] * self.capabilities.dof,
- effort=efforts or [0.0] * self.capabilities.dof,
- )
- self.joint_state.publish(joint_state_msg)
-
- except Exception as e:
- self.logger.error(f"Error updating joint state: {e}")
-
- def _update_robot_state(self) -> None:
- """Update robot state from hardware (low frequency - 10Hz).
-
- Reads robot mode, errors, warnings, optional states and publishes to LCM immediately.
- """
- try:
- # Get robot state (mode, errors, warnings)
- robot_state = self.sdk.get_robot_state()
- self.shared_state.update_robot_state(
- state=robot_state.get("state", 0),
- mode=robot_state.get("mode", 0),
- error_code=robot_state.get("error_code", 0),
- error_message=self.sdk.get_error_message(),
- )
-
- # Update status flags
- self.shared_state.is_moving = robot_state.get("is_moving", False)
- self.shared_state.is_enabled = self.sdk.are_servos_enabled()
-
- # Get optional states (cartesian, force/torque, gripper)
- if self.capabilities.has_cartesian_control:
- cart_pos = self.sdk.get_cartesian_position()
- if cart_pos:
- self.shared_state.cartesian_position = cart_pos
-
- if self.capabilities.has_force_torque:
- ft = self.sdk.get_force_torque()
- if ft:
- self.shared_state.force_torque = ft
-
- if self.capabilities.has_gripper:
- gripper_pos = self.sdk.get_gripper_position()
- if gripper_pos is not None:
- self.shared_state.gripper_position = gripper_pos
-
- # Publish robot state immediately at monitor rate
- if self.robot_state and hasattr(self.robot_state, "publish"):
- robot_state_msg = RobotState(
- state=self.shared_state.robot_state,
- mode=self.shared_state.control_mode,
- error_code=self.shared_state.error_code,
- warn_code=0,
- )
- self.robot_state.publish(robot_state_msg)
-
- # Publish force/torque if available
- if (
- self.ft_sensor
- and hasattr(self.ft_sensor, "publish")
- and self.capabilities.has_force_torque
- ):
- if self.shared_state.force_torque:
- ft_msg = WrenchStamped.from_force_torque_array(
- ft_data=self.shared_state.force_torque,
- frame_id="ft_sensor",
- ts=time.time(),
- )
- self.ft_sensor.publish(ft_msg)
-
- except Exception as e:
- self.logger.error(f"Error updating robot state: {e}")
- self.shared_state.update_robot_state(error_code=999, error_message=str(e))
-
- # ============= Threading =============
-
- @rpc
- def start(self) -> None:
- """Start all driver threads and subscribe to input topics."""
- super().start()
- self.logger.info(f"Starting {self.name} driver threads...")
-
- # Subscribe to input topics if they have transports
- try:
- if self.joint_position_command and hasattr(self.joint_position_command, "subscribe"):
- self.joint_position_command.subscribe(self._on_joint_position_command)
- self.logger.debug("Subscribed to joint_position_command")
- except (AttributeError, ValueError) as e:
- self.logger.debug(f"joint_position_command transport not configured: {e}")
-
- try:
- if self.joint_velocity_command and hasattr(self.joint_velocity_command, "subscribe"):
- self.joint_velocity_command.subscribe(self._on_joint_velocity_command)
- self.logger.debug("Subscribed to joint_velocity_command")
- except (AttributeError, ValueError) as e:
- self.logger.debug(f"joint_velocity_command transport not configured: {e}")
-
- self.threads = [
- Thread(target=self._control_loop_thread, name=f"{self.name}-ControlLoop", daemon=True),
- Thread(
- target=self._robot_state_monitor_thread,
- name=f"{self.name}-StateMonitor",
- daemon=True,
- ),
- ]
-
- for thread in self.threads:
- thread.start()
- self.logger.debug(f"Started thread: {thread.name}")
-
- self.logger.info(f"{self.name} driver started successfully")
-
- def _control_loop_thread(self) -> None:
- """Control loop: send commands AND read joint feedback (100Hz).
-
- This tight loop ensures synchronized command/feedback for real-time control.
- """
- self.logger.debug("Control loop thread started")
- period = 1.0 / self.control_rate
- next_time = time.perf_counter() + period # perf_counter for precise timing
-
- while not self.stop_event.is_set():
- try:
- # 1. Process all pending commands (non-blocking)
- while True:
- try:
- command = self.command_queue.get_nowait() # Non-blocking (optimization)
- self._process_command(command)
- except Empty:
- break # No more commands
-
- # 2. Read joint state feedback (critical for control)
- self._update_joint_state()
-
- except Exception as e:
- self.logger.error(f"Control loop error: {e}")
-
- # Rate control - maintain precise timing
- next_time += period
- sleep_time = next_time - time.perf_counter()
- if sleep_time > 0:
- time.sleep(sleep_time)
- else:
- # Fell behind - reset timing
- next_time = time.perf_counter() + period
- if sleep_time < -period:
- self.logger.warning(f"Control loop fell behind by {-sleep_time:.3f}s")
-
- self.logger.debug("Control loop thread stopped")
-
- def _robot_state_monitor_thread(self) -> None:
- """Monitor robot state: mode, errors, warnings (10-20Hz).
-
- Lower frequency monitoring for high-level planning and error handling.
- """
- self.logger.debug("Robot state monitor thread started")
- period = 1.0 / self.monitor_rate
- next_time = time.perf_counter() + period # perf_counter for precise timing
-
- while not self.stop_event.is_set():
- try:
- # Read robot state, mode, errors, optional states
- self._update_robot_state()
- except Exception as e:
- self.logger.error(f"Robot state monitor error: {e}")
-
- # Rate control
- next_time += period
- sleep_time = next_time - time.perf_counter()
- if sleep_time > 0:
- time.sleep(sleep_time)
- else:
- next_time = time.perf_counter() + period
-
- self.logger.debug("Robot state monitor thread stopped")
-
- def _process_command(self, command: Command) -> None:
- """Process a command from the queue.
-
- Args:
- command: Command to process
- """
- try:
- if command.type == "position":
- success = self.sdk.set_joint_positions(
- command.data["positions"],
- command.data.get("velocity", 1.0),
- command.data.get("acceleration", 1.0),
- command.data.get("wait", False),
- )
- if success:
- self.shared_state.target_positions = command.data["positions"]
-
- elif command.type == "velocity":
- success = self.sdk.set_joint_velocities(command.data["velocities"])
- if success:
- self.shared_state.target_velocities = command.data["velocities"]
-
- elif command.type == "effort":
- success = self.sdk.set_joint_efforts(command.data["efforts"])
- if success:
- self.shared_state.target_efforts = command.data["efforts"]
-
- elif command.type == "cartesian":
- success = self.sdk.set_cartesian_position(
- command.data["pose"],
- command.data.get("velocity", 1.0),
- command.data.get("acceleration", 1.0),
- command.data.get("wait", False),
- )
- if success:
- self.shared_state.target_cartesian_position = command.data["pose"]
-
- elif command.type == "stop":
- self.sdk.stop_motion()
-
- else:
- self.logger.warning(f"Unknown command type: {command.type}")
-
- except Exception as e:
- self.logger.error(f"Error processing command {command.type}: {e}")
-
- # ============= Input Callbacks =============
-
- def _on_joint_position_command(self, cmd_msg: JointCommand) -> None:
- """Callback when joint position command is received.
-
- Args:
- cmd_msg: JointCommand message containing positions
- """
- command = Command(
- type="position", data={"positions": list(cmd_msg.positions)}, timestamp=time.time()
- )
- try:
- self.command_queue.put_nowait(command)
- except:
- self.logger.warning("Command queue full, dropping position command")
-
- def _on_joint_velocity_command(self, cmd_msg: JointCommand) -> None:
- """Callback when joint velocity command is received.
-
- Args:
- cmd_msg: JointCommand message containing velocities
- """
- command = Command(
- type="velocity",
- data={"velocities": list(cmd_msg.positions)}, # JointCommand uses 'positions' field
- timestamp=time.time(),
- )
- try:
- self.command_queue.put_nowait(command)
- except:
- self.logger.warning("Command queue full, dropping velocity command")
-
- # ============= Lifecycle Management =============
-
- @rpc
- def stop(self) -> None:
- """Stop all threads and disconnect from hardware."""
- self.logger.info(f"Stopping {self.name} driver...")
-
- # Signal threads to stop
- self.stop_event.set()
-
- # Stop any ongoing motion
- try:
- self.sdk.stop_motion()
- except:
- pass
-
- # Wait for threads to stop
- for thread in self.threads:
- thread.join(timeout=2.0)
- if thread.is_alive():
- self.logger.warning(f"Thread {thread.name} did not stop cleanly")
-
- # Disconnect from hardware
- try:
- self.sdk.disconnect()
- except:
- pass
-
- self.shared_state.is_connected = False
- self.logger.info(f"{self.name} driver stopped")
-
- # Call Module's stop
- super().stop()
-
- def __del__(self) -> None:
- """Cleanup on deletion."""
- if self.shared_state.is_connected:
- self.stop()
-
- # ============= RPC Method Access =============
-
- def get_rpc_method(self, method_name: str) -> Any:
- """Get an RPC method by name.
-
- Args:
- method_name: Name of the RPC method
-
- Returns:
- The method if found, None otherwise
- """
- return self.rpc_methods.get(method_name)
-
- def list_rpc_methods(self) -> list[str]:
- """List all available RPC methods.
-
- Returns:
- List of RPC method names
- """
- return list(self.rpc_methods.keys())
-
- # ============= Component Access =============
-
- def get_component(self, component_type: type[Any]) -> Any:
- """Get a component by type.
-
- Args:
- component_type: Type of component to find
-
- Returns:
- The component if found, None otherwise
- """
- for component in self.components:
- if isinstance(component, component_type):
- return component
- return None
-
- def add_component(self, component: Any) -> None:
- """Add a component at runtime.
-
- Args:
- component: Component instance to add
- """
- self.components.append(component)
- self._initialize_components()
- self._auto_expose_component_apis()
-
- def remove_component(self, component: Any) -> None:
- """Remove a component at runtime.
-
- Args:
- component: Component instance to remove
- """
- if component in self.components:
- self.components.remove(component)
- # Clean up old exposed methods and re-expose for remaining components
- self._cleanup_exposed_component_apis()
- self._auto_expose_component_apis()
-
- def _cleanup_exposed_component_apis(self) -> None:
- """Remove all auto-exposed component API methods from the driver."""
- for method_name in self._exposed_component_apis:
- if hasattr(self, method_name):
- delattr(self, method_name)
- self._exposed_component_apis.clear()
- self.rpc_methods.clear()
diff --git a/dimos/hardware/manipulators/base/sdk_interface.py b/dimos/hardware/manipulators/base/sdk_interface.py
deleted file mode 100644
index f20d35bd50..0000000000
--- a/dimos/hardware/manipulators/base/sdk_interface.py
+++ /dev/null
@@ -1,471 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base SDK interface that all manipulator SDK wrappers must implement.
-
-This interface defines the standard methods and units that all SDK wrappers
-must provide, ensuring consistent behavior across different manipulator types.
-
-Standard Units:
-- Angles: radians
-- Angular velocity: rad/s
-- Linear position: meters
-- Linear velocity: m/s
-- Force: Newtons
-- Torque: Nm
-- Time: seconds
-"""
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any
-
-
-@dataclass
-class ManipulatorInfo:
- """Information about the manipulator."""
-
- vendor: str
- model: str
- dof: int
- firmware_version: str | None = None
- serial_number: str | None = None
-
-
-class BaseManipulatorSDK(ABC):
- """Abstract base class for manipulator SDK wrappers.
-
- All SDK wrappers must implement this interface to ensure compatibility
- with the standard components. Methods should handle unit conversions
- internally to always work with standard units.
- """
-
- # ============= Connection Management =============
-
- @abstractmethod
- def connect(self, config: dict[str, Any]) -> bool:
- """Establish connection to the manipulator.
-
- Args:
- config: Configuration dict with connection parameters
- (e.g., ip, port, can_interface, etc.)
-
- Returns:
- True if connection successful, False otherwise
- """
- pass
-
- @abstractmethod
- def disconnect(self) -> None:
- """Disconnect from the manipulator.
-
- Should cleanly close all connections and free resources.
- """
- pass
-
- @abstractmethod
- def is_connected(self) -> bool:
- """Check if currently connected to the manipulator.
-
- Returns:
- True if connected, False otherwise
- """
- pass
-
- # ============= Joint State Query =============
-
- @abstractmethod
- def get_joint_positions(self) -> list[float]:
- """Get current joint positions.
-
- Returns:
- Joint positions in RADIANS
- """
- pass
-
- @abstractmethod
- def get_joint_velocities(self) -> list[float]:
- """Get current joint velocities.
-
- Returns:
- Joint velocities in RAD/S
- """
- pass
-
- @abstractmethod
- def get_joint_efforts(self) -> list[float]:
- """Get current joint efforts/torques.
-
- Returns:
- Joint efforts in Nm (torque) or N (force)
- """
- pass
-
- # ============= Joint Motion Control =============
-
- @abstractmethod
- def set_joint_positions(
- self,
- positions: list[float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> bool:
- """Move joints to target positions.
-
- Args:
- positions: Target positions in RADIANS
- velocity: Max velocity as fraction of maximum (0-1)
- acceleration: Max acceleration as fraction of maximum (0-1)
- wait: If True, block until motion completes
-
- Returns:
- True if command accepted, False otherwise
- """
- pass
-
- @abstractmethod
- def set_joint_velocities(self, velocities: list[float]) -> bool:
- """Set joint velocity targets.
-
- Args:
- velocities: Target velocities in RAD/S
-
- Returns:
- True if command accepted, False otherwise
- """
- pass
-
- @abstractmethod
- def set_joint_efforts(self, efforts: list[float]) -> bool:
- """Set joint effort/torque targets.
-
- Args:
- efforts: Target efforts in Nm (torque) or N (force)
-
- Returns:
- True if command accepted, False otherwise
- """
- pass
-
- @abstractmethod
- def stop_motion(self) -> bool:
- """Stop all ongoing motion immediately.
-
- Returns:
- True if stop successful, False otherwise
- """
- pass
-
- # ============= Servo Control =============
-
- @abstractmethod
- def enable_servos(self) -> bool:
- """Enable motor control (servos/brakes released).
-
- Returns:
- True if servos enabled, False otherwise
- """
- pass
-
- @abstractmethod
- def disable_servos(self) -> bool:
- """Disable motor control (servos/brakes engaged).
-
- Returns:
- True if servos disabled, False otherwise
- """
- pass
-
- @abstractmethod
- def are_servos_enabled(self) -> bool:
- """Check if servos are currently enabled.
-
- Returns:
- True if enabled, False if disabled
- """
- pass
-
- # ============= System State =============
-
- @abstractmethod
- def get_robot_state(self) -> dict[str, Any]:
- """Get current robot state information.
-
- Returns:
- Dict with at least these keys:
- - 'state': int (0=idle, 1=moving, 2=error, 3=e-stop)
- - 'mode': int (0=position, 1=velocity, 2=torque)
- - 'error_code': int (0 = no error)
- - 'is_moving': bool
- """
- pass
-
- @abstractmethod
- def get_error_code(self) -> int:
- """Get current error code.
-
- Returns:
- Error code (0 = no error)
- """
- pass
-
- @abstractmethod
- def get_error_message(self) -> str:
- """Get human-readable error message.
-
- Returns:
- Error message string (empty if no error)
- """
- pass
-
- @abstractmethod
- def clear_errors(self) -> bool:
- """Clear any error states.
-
- Returns:
- True if errors cleared, False otherwise
- """
- pass
-
- @abstractmethod
- def emergency_stop(self) -> bool:
- """Execute emergency stop.
-
- Returns:
- True if e-stop executed, False otherwise
- """
- pass
-
- # ============= Information =============
-
- @abstractmethod
- def get_info(self) -> ManipulatorInfo:
- """Get manipulator information.
-
- Returns:
- ManipulatorInfo object with vendor, model, DOF, etc.
- """
- pass
-
- @abstractmethod
- def get_joint_limits(self) -> tuple[list[float], list[float]]:
- """Get joint position limits.
-
- Returns:
- Tuple of (lower_limits, upper_limits) in RADIANS
- """
- pass
-
- @abstractmethod
- def get_velocity_limits(self) -> list[float]:
- """Get joint velocity limits.
-
- Returns:
- Maximum velocities in RAD/S
- """
- pass
-
- @abstractmethod
- def get_acceleration_limits(self) -> list[float]:
- """Get joint acceleration limits.
-
- Returns:
- Maximum accelerations in RAD/S²
- """
- pass
-
- # ============= Optional Methods (Override if Supported) =============
- # These have default implementations that indicate feature not available
-
- def get_cartesian_position(self) -> dict[str, float] | None:
- """Get current end-effector pose.
-
- Returns:
- Dict with keys: x, y, z (meters), roll, pitch, yaw (radians)
- None if not supported
- """
- return None
-
- def set_cartesian_position(
- self,
- pose: dict[str, float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> bool:
- """Move end-effector to target pose.
-
- Args:
- pose: Target pose with keys: x, y, z (meters), roll, pitch, yaw (radians)
- velocity: Max velocity as fraction (0-1)
- acceleration: Max acceleration as fraction (0-1)
- wait: If True, block until motion completes
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def get_cartesian_velocity(self) -> dict[str, float] | None:
- """Get current end-effector velocity.
-
- Returns:
- Dict with keys: vx, vy, vz (m/s), wx, wy, wz (rad/s)
- None if not supported
- """
- return None
-
- def set_cartesian_velocity(self, twist: dict[str, float]) -> bool:
- """Set end-effector velocity.
-
- Args:
- twist: Velocity with keys: vx, vy, vz (m/s), wx, wy, wz (rad/s)
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def get_force_torque(self) -> list[float] | None:
- """Get force/torque sensor reading.
-
- Returns:
- List of [fx, fy, fz (N), tx, ty, tz (Nm)]
- None if not supported
- """
- return None
-
- def zero_force_torque(self) -> bool:
- """Zero the force/torque sensor.
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def set_impedance_parameters(self, stiffness: list[float], damping: list[float]) -> bool:
- """Set impedance control parameters.
-
- Args:
- stiffness: Stiffness values [x, y, z, rx, ry, rz]
- damping: Damping values [x, y, z, rx, ry, rz]
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def get_digital_inputs(self) -> dict[str, bool] | None:
- """Get digital input states.
-
- Returns:
- Dict of input_id: bool
- None if not supported
- """
- return None
-
- def set_digital_outputs(self, outputs: dict[str, bool]) -> bool:
- """Set digital output states.
-
- Args:
- outputs: Dict of output_id: bool
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def get_analog_inputs(self) -> dict[str, float] | None:
- """Get analog input values.
-
- Returns:
- Dict of input_id: float
- None if not supported
- """
- return None
-
- def set_analog_outputs(self, outputs: dict[str, float]) -> bool:
- """Set analog output values.
-
- Args:
- outputs: Dict of output_id: float
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def execute_trajectory(self, trajectory: list[dict[str, Any]], wait: bool = True) -> bool:
- """Execute a joint trajectory.
-
- Args:
- trajectory: List of waypoints, each with:
- - 'positions': list[float] in radians
- - 'velocities': Optional list[float] in rad/s
- - 'time': float seconds from start
- wait: If True, block until trajectory completes
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def stop_trajectory(self) -> bool:
- """Stop any executing trajectory.
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def get_gripper_position(self) -> float | None:
- """Get gripper position.
-
- Returns:
- Position in meters (0=closed, max=fully open)
- None if no gripper
- """
- return None
-
- def set_gripper_position(self, position: float, force: float = 1.0) -> bool:
- """Set gripper position.
-
- Args:
- position: Target position in meters
- force: Gripping force as fraction (0-1)
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def set_control_mode(self, mode: str) -> bool:
- """Set control mode.
-
- Args:
- mode: One of 'position', 'velocity', 'torque', 'impedance'
-
- Returns:
- False (not supported by default)
- """
- return False
-
- def get_control_mode(self) -> str | None:
- """Get current control mode.
-
- Returns:
- Current mode string or None if not supported
- """
- return None
diff --git a/dimos/hardware/manipulators/base/spec.py b/dimos/hardware/manipulators/base/spec.py
deleted file mode 100644
index 8a0722cf09..0000000000
--- a/dimos/hardware/manipulators/base/spec.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Any, Protocol
-
-from dimos.core import In, Out
-from dimos.msgs.geometry_msgs import WrenchStamped
-from dimos.msgs.sensor_msgs import JointCommand, JointState
-
-
-@dataclass
-class RobotState:
- """Universal robot state compatible with all manipulators."""
-
- # Core state fields (all manipulators must provide these)
- state: int = 0 # 0: idle, 1: moving, 2: error, 3: e-stop
- mode: int = 0 # 0: position, 1: velocity, 2: torque, 3: impedance
- error_code: int = 0 # Standardized error codes across all arms
- warn_code: int = 0 # Standardized warning codes
-
- # Extended state (optional, arm-specific)
- is_connected: bool = False
- is_enabled: bool = False
- is_moving: bool = False
- is_collision: bool = False
-
- # Vendor-specific data (if needed)
- vendor_data: dict[str, Any] | None = None
-
-
-@dataclass
-class ManipulatorCapabilities:
- """Describes what a manipulator can do."""
-
- dof: int # Degrees of freedom
- has_gripper: bool = False
- has_force_torque: bool = False
- has_impedance_control: bool = False
- has_cartesian_control: bool = False
- max_joint_velocity: list[float] | None = None # rad/s
- max_joint_acceleration: list[float] | None = None # rad/s²
- joint_limits_lower: list[float] | None = None # rad
- joint_limits_upper: list[float] | None = None # rad
- payload_mass: float = 0.0 # kg
- reach: float = 0.0 # meters
-
-
-class ManipulatorDriverSpec(Protocol):
- """Universal protocol specification for ALL manipulator drivers.
-
- This defines the standard interface that every manipulator driver
- must implement, regardless of the underlying hardware (XArm, Piper,
- UR, Franka, etc.).
-
- ## Component-Based Architecture
-
- Drivers use a **component-based architecture** where functionality is provided
- by composable components:
-
- - **StandardMotionComponent**: Joint/cartesian motion, trajectory execution
- - **StandardServoComponent**: Servo control, modes, emergency stop, error handling
- - **StandardStatusComponent**: State monitoring, capabilities, diagnostics
-
- RPC methods are provided by components and registered with the driver.
- Access them via:
-
- ```python
- # Method 1: Via component (direct access)
- motion = driver.get_component(StandardMotionComponent)
- motion.rpc_move_joint(positions=[0, 0, 0, 0, 0, 0])
-
- # Method 2: Via driver's RPC registry
- move_fn = driver.get_rpc_method('rpc_move_joint')
- move_fn(positions=[0, 0, 0, 0, 0, 0])
-
- # Method 3: Via blueprints (recommended - automatic routing)
- # Commands sent to input topics are automatically routed to components
- driver.joint_position_command.publish(JointCommand(positions=[0, 0, 0, 0, 0, 0]))
- ```
-
- ## Required Components
-
- Every driver must include these standard components:
- - `StandardMotionComponent` - Provides motion control RPC methods
- - `StandardServoComponent` - Provides servo control RPC methods
- - `StandardStatusComponent` - Provides status monitoring RPC methods
-
- ## Available RPC Methods (via Components)
-
- ### Motion Control (StandardMotionComponent)
- - `rpc_move_joint()` - Move to joint positions
- - `rpc_move_joint_velocity()` - Set joint velocities
- - `rpc_move_joint_effort()` - Set joint efforts (optional)
- - `rpc_stop_motion()` - Stop all motion
- - `rpc_get_joint_state()` - Get current joint state
- - `rpc_get_joint_limits()` - Get joint limits
- - `rpc_move_cartesian()` - Cartesian motion (optional)
- - `rpc_execute_trajectory()` - Execute trajectory (optional)
-
- ### Servo Control (StandardServoComponent)
- - `rpc_enable_servo()` - Enable motor control
- - `rpc_disable_servo()` - Disable motor control
- - `rpc_set_control_mode()` - Set control mode
- - `rpc_emergency_stop()` - Execute emergency stop
- - `rpc_clear_errors()` - Clear error states
- - `rpc_home_robot()` - Home the robot
-
- ### Status Monitoring (StandardStatusComponent)
- - `rpc_get_robot_state()` - Get robot state
- - `rpc_get_capabilities()` - Get capabilities
- - `rpc_get_system_info()` - Get system information
- - `rpc_check_connection()` - Check connection status
-
- ## Standardized Units
-
- All units are standardized:
- - Angles: radians
- - Angular velocity: rad/s
- - Linear position: meters
- - Linear velocity: m/s
- - Force: Newtons
- - Torque: Nm
- - Time: seconds
- """
-
- # ============= Capabilities Declaration =============
- capabilities: ManipulatorCapabilities
-
- # ============= Input Topics (Commands) =============
- # Core control inputs (all manipulators must support these)
- joint_position_command: In[JointCommand] # Target joint positions (rad)
- joint_velocity_command: In[JointCommand] # Target joint velocities (rad/s)
-
- # ============= Output Topics (Feedback) =============
- # Core feedback (all manipulators must provide these)
- joint_state: Out[JointState] # Current positions, velocities, efforts
- robot_state: Out[RobotState] # System state and health
-
- # Optional feedback (capability-dependent)
- ft_sensor: Out[WrenchStamped] | None # Force/torque sensor data
-
- # ============= Component Access =============
- def get_component(self, component_type: type) -> Any:
- """Get a component by type.
-
- Args:
- component_type: Type of component to retrieve
-
- Returns:
- Component instance if found, None otherwise
-
- Example:
- motion = driver.get_component(StandardMotionComponent)
- motion.rpc_move_joint([0, 0, 0, 0, 0, 0])
- """
- pass
-
- def get_rpc_method(self, method_name: str) -> Any:
- """Get an RPC method by name.
-
- Args:
- method_name: Name of the RPC method (e.g., 'rpc_move_joint')
-
- Returns:
- Callable method if found, None otherwise
-
- Example:
- move_fn = driver.get_rpc_method('rpc_move_joint')
- result = move_fn(positions=[0, 0, 0, 0, 0, 0])
- """
- ...
-
- def list_rpc_methods(self) -> list[str]:
- """List all available RPC methods from all components.
-
- Returns:
- List of RPC method names
-
- Example:
- methods = driver.list_rpc_methods()
- # ['rpc_move_joint', 'rpc_enable_servo', 'rpc_get_robot_state', ...]
- """
- ...
diff --git a/dimos/hardware/manipulators/base/tests/conftest.py b/dimos/hardware/manipulators/base/tests/conftest.py
deleted file mode 100644
index d3e6a4c66d..0000000000
--- a/dimos/hardware/manipulators/base/tests/conftest.py
+++ /dev/null
@@ -1,362 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Pytest fixtures and mocks for manipulator driver tests.
-
-This module contains MockSDK which implements BaseManipulatorSDK with controllable
-behavior for testing driver logic without requiring hardware.
-
-Features:
-- Configurable initial state (positions, DOF, vendor, model)
-- Call tracking for verification
-- Configurable error injection
-- Simulated behavior (e.g., position updates)
-"""
-
-from dataclasses import dataclass, field
-import math
-
-import pytest
-
-from ..sdk_interface import BaseManipulatorSDK, ManipulatorInfo
-
-
-@dataclass
-class MockSDKConfig:
- """Configuration for MockSDK behavior."""
-
- dof: int = 6
- vendor: str = "Mock"
- model: str = "TestArm"
- initial_positions: list[float] | None = None
- initial_velocities: list[float] | None = None
- initial_efforts: list[float] | None = None
-
- # Error injection
- connect_fails: bool = False
- enable_fails: bool = False
- motion_fails: bool = False
- error_code: int = 0
-
- # Behavior options
- simulate_motion: bool = False # If True, set_joint_positions updates internal state
-
-
-@dataclass
-class CallRecord:
- """Record of a method call for verification."""
-
- method: str
- args: tuple = field(default_factory=tuple)
- kwargs: dict = field(default_factory=dict)
-
-
-class MockSDK(BaseManipulatorSDK):
- """Mock SDK for unit testing. Implements BaseManipulatorSDK interface.
-
- Usage:
- # Basic usage
- mock = MockSDK()
- driver = create_driver_with_sdk(mock)
- driver.enable_servo()
- assert mock.enable_servos_called
-
- # With custom config
- config = MockSDKConfig(dof=7, connect_fails=True)
- mock = MockSDK(config=config)
-
- # With initial positions
- mock = MockSDK(positions=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
-
- # Verify calls
- mock.set_joint_positions([0.1] * 6)
- assert mock.was_called("set_joint_positions")
- assert mock.call_count("set_joint_positions") == 1
- """
-
- def __init__(
- self,
- config: MockSDKConfig | None = None,
- *,
- dof: int = 6,
- vendor: str = "Mock",
- model: str = "TestArm",
- positions: list[float] | None = None,
- ):
- """Initialize MockSDK.
-
- Args:
- config: Full configuration object (takes precedence)
- dof: Degrees of freedom (ignored if config provided)
- vendor: Vendor name (ignored if config provided)
- model: Model name (ignored if config provided)
- positions: Initial joint positions (ignored if config provided)
- """
- if config is None:
- config = MockSDKConfig(
- dof=dof,
- vendor=vendor,
- model=model,
- initial_positions=positions,
- )
-
- self._config = config
- self._dof = config.dof
- self._vendor = config.vendor
- self._model = config.model
-
- # State
- self._connected = False
- self._servos_enabled = False
- self._positions = list(config.initial_positions or [0.0] * self._dof)
- self._velocities = list(config.initial_velocities or [0.0] * self._dof)
- self._efforts = list(config.initial_efforts or [0.0] * self._dof)
- self._mode = 0
- self._state = 0
- self._error_code = config.error_code
-
- # Call tracking
- self._calls: list[CallRecord] = []
-
- # Convenience flags for simple assertions
- self.connect_called = False
- self.disconnect_called = False
- self.enable_servos_called = False
- self.disable_servos_called = False
- self.set_joint_positions_called = False
- self.set_joint_velocities_called = False
- self.stop_motion_called = False
- self.emergency_stop_called = False
- self.clear_errors_called = False
-
- def _record_call(self, method: str, *args, **kwargs):
- """Record a method call."""
- self._calls.append(CallRecord(method=method, args=args, kwargs=kwargs))
-
- def was_called(self, method: str) -> bool:
- """Check if a method was called."""
- return any(c.method == method for c in self._calls)
-
- def call_count(self, method: str) -> int:
- """Get the number of times a method was called."""
- return sum(1 for c in self._calls if c.method == method)
-
- def get_calls(self, method: str) -> list[CallRecord]:
- """Get all calls to a specific method."""
- return [c for c in self._calls if c.method == method]
-
- def get_last_call(self, method: str) -> CallRecord | None:
- """Get the last call to a specific method."""
- calls = self.get_calls(method)
- return calls[-1] if calls else None
-
- def reset_calls(self):
- """Reset call tracking."""
- self._calls.clear()
- self.connect_called = False
- self.disconnect_called = False
- self.enable_servos_called = False
- self.disable_servos_called = False
- self.set_joint_positions_called = False
- self.set_joint_velocities_called = False
- self.stop_motion_called = False
- self.emergency_stop_called = False
- self.clear_errors_called = False
-
- # ============= State Manipulation (for test setup) =============
-
- def set_positions(self, positions: list[float]):
- """Set internal positions (test helper)."""
- self._positions = list(positions)
-
- def set_error(self, code: int, message: str = ""):
- """Inject an error state (test helper)."""
- self._error_code = code
-
- def set_enabled(self, enabled: bool):
- """Set servo enabled state (test helper)."""
- self._servos_enabled = enabled
-
- # ============= BaseManipulatorSDK Implementation =============
-
- def connect(self, config: dict) -> bool:
- self._record_call("connect", config)
- self.connect_called = True
-
- if self._config.connect_fails:
- return False
-
- self._connected = True
- return True
-
- def disconnect(self) -> None:
- self._record_call("disconnect")
- self.disconnect_called = True
- self._connected = False
-
- def is_connected(self) -> bool:
- self._record_call("is_connected")
- return self._connected
-
- def get_joint_positions(self) -> list[float]:
- self._record_call("get_joint_positions")
- return self._positions.copy()
-
- def get_joint_velocities(self) -> list[float]:
- self._record_call("get_joint_velocities")
- return self._velocities.copy()
-
- def get_joint_efforts(self) -> list[float]:
- self._record_call("get_joint_efforts")
- return self._efforts.copy()
-
- def set_joint_positions(
- self,
- positions: list[float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> bool:
- self._record_call(
- "set_joint_positions",
- positions,
- velocity=velocity,
- acceleration=acceleration,
- wait=wait,
- )
- self.set_joint_positions_called = True
-
- if self._config.motion_fails:
- return False
-
- if not self._servos_enabled:
- return False
-
- if self._config.simulate_motion:
- self._positions = list(positions)
-
- return True
-
- def set_joint_velocities(self, velocities: list[float]) -> bool:
- self._record_call("set_joint_velocities", velocities)
- self.set_joint_velocities_called = True
-
- if self._config.motion_fails:
- return False
-
- if not self._servos_enabled:
- return False
-
- self._velocities = list(velocities)
- return True
-
- def set_joint_efforts(self, efforts: list[float]) -> bool:
- self._record_call("set_joint_efforts", efforts)
- return False # Not supported in mock
-
- def stop_motion(self) -> bool:
- self._record_call("stop_motion")
- self.stop_motion_called = True
- self._velocities = [0.0] * self._dof
- return True
-
- def enable_servos(self) -> bool:
- self._record_call("enable_servos")
- self.enable_servos_called = True
-
- if self._config.enable_fails:
- return False
-
- self._servos_enabled = True
- return True
-
- def disable_servos(self) -> bool:
- self._record_call("disable_servos")
- self.disable_servos_called = True
- self._servos_enabled = False
- return True
-
- def are_servos_enabled(self) -> bool:
- self._record_call("are_servos_enabled")
- return self._servos_enabled
-
- def get_robot_state(self) -> dict:
- self._record_call("get_robot_state")
- return {
- "state": self._state,
- "mode": self._mode,
- "error_code": self._error_code,
- "is_moving": any(v != 0 for v in self._velocities),
- }
-
- def get_error_code(self) -> int:
- self._record_call("get_error_code")
- return self._error_code
-
- def get_error_message(self) -> str:
- self._record_call("get_error_message")
- return "" if self._error_code == 0 else f"Mock error {self._error_code}"
-
- def clear_errors(self) -> bool:
- self._record_call("clear_errors")
- self.clear_errors_called = True
- self._error_code = 0
- return True
-
- def emergency_stop(self) -> bool:
- self._record_call("emergency_stop")
- self.emergency_stop_called = True
- self._velocities = [0.0] * self._dof
- self._servos_enabled = False
- return True
-
- def get_info(self) -> ManipulatorInfo:
- self._record_call("get_info")
- return ManipulatorInfo(
- vendor=self._vendor,
- model=f"{self._model} (Mock)",
- dof=self._dof,
- firmware_version="mock-1.0.0",
- serial_number="MOCK-001",
- )
-
- def get_joint_limits(self) -> tuple[list[float], list[float]]:
- self._record_call("get_joint_limits")
- lower = [-2 * math.pi] * self._dof
- upper = [2 * math.pi] * self._dof
- return lower, upper
-
- def get_velocity_limits(self) -> list[float]:
- self._record_call("get_velocity_limits")
- return [math.pi] * self._dof
-
- def get_acceleration_limits(self) -> list[float]:
- self._record_call("get_acceleration_limits")
- return [math.pi * 2] * self._dof
-
-
-# ============= Pytest Fixtures =============
-
-
-@pytest.fixture
-def mock_sdk():
- """Create a basic MockSDK."""
- return MockSDK(dof=6)
-
-
-@pytest.fixture
-def mock_sdk_with_positions():
- """Create MockSDK with initial positions."""
- positions = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
- return MockSDK(positions=positions)
diff --git a/dimos/hardware/manipulators/base/tests/test_driver_unit.py b/dimos/hardware/manipulators/base/tests/test_driver_unit.py
deleted file mode 100644
index b305d8cd15..0000000000
--- a/dimos/hardware/manipulators/base/tests/test_driver_unit.py
+++ /dev/null
@@ -1,577 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Unit tests for BaseManipulatorDriver.
-
-These tests use MockSDK to test driver logic in isolation without hardware.
-Run with: pytest dimos/hardware/manipulators/base/tests/test_driver_unit.py -v
-"""
-
-import math
-import time
-
-import pytest
-
-from ..components import (
- StandardMotionComponent,
- StandardServoComponent,
- StandardStatusComponent,
-)
-from ..driver import BaseManipulatorDriver
-from .conftest import MockSDK, MockSDKConfig
-
-# =============================================================================
-# Fixtures
-# =============================================================================
-# Note: mock_sdk and mock_sdk_with_positions fixtures are defined in conftest.py
-
-
-@pytest.fixture
-def standard_components():
- """Create standard component set."""
- return [
- StandardMotionComponent(),
- StandardServoComponent(),
- StandardStatusComponent(),
- ]
-
-
-@pytest.fixture
-def driver(mock_sdk, standard_components):
- """Create a driver with MockSDK and standard components."""
- config = {"dof": 6}
- driver = BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config=config,
- name="TestDriver",
- )
- yield driver
- # Cleanup - stop driver if running
- try:
- driver.stop()
- except Exception:
- pass
-
-
-@pytest.fixture
-def started_driver(driver):
- """Create and start a driver."""
- driver.start()
- time.sleep(0.05) # Allow threads to start
- yield driver
-
-
-# =============================================================================
-# Connection Tests
-# =============================================================================
-
-
-class TestConnection:
- """Tests for driver connection behavior."""
-
- def test_driver_connects_on_init(self, mock_sdk, standard_components):
- """Driver should connect to SDK during initialization."""
- config = {"dof": 6}
- driver = BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config=config,
- name="TestDriver",
- )
-
- assert mock_sdk.connect_called
- assert mock_sdk.is_connected()
- assert driver.shared_state.is_connected
-
- driver.stop()
-
- @pytest.mark.skip(
- reason="Driver init failure leaks LCM threads - needs cleanup fix in Module base class"
- )
- def test_connection_failure_raises(self, standard_components):
- """Driver should raise if SDK connection fails."""
- config_fail = MockSDKConfig(connect_fails=True)
- mock_sdk = MockSDK(config=config_fail)
-
- with pytest.raises(RuntimeError, match="Failed to connect"):
- BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config={"dof": 6},
- name="TestDriver",
- )
-
- def test_disconnect_on_stop(self, started_driver, mock_sdk):
- """Driver should disconnect SDK on stop."""
- started_driver.stop()
-
- assert mock_sdk.disconnect_called
- assert not started_driver.shared_state.is_connected
-
-
-# =============================================================================
-# Joint State Tests
-# =============================================================================
-
-
-class TestJointState:
- """Tests for joint state reading."""
-
- def test_get_joint_state_returns_positions(self, driver):
- """get_joint_state should return current positions."""
- result = driver.get_joint_state()
-
- assert result["success"] is True
- assert len(result["positions"]) == 6
- assert len(result["velocities"]) == 6
- assert len(result["efforts"]) == 6
-
- def test_get_joint_state_with_custom_positions(self, standard_components):
- """get_joint_state should return SDK positions."""
- expected_positions = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
- mock_sdk = MockSDK(positions=expected_positions)
-
- driver = BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config={"dof": 6},
- name="TestDriver",
- )
-
- result = driver.get_joint_state()
-
- assert result["positions"] == expected_positions
-
- driver.stop()
-
- def test_shared_state_updated_on_joint_read(self, driver):
- """Shared state should be updated when reading joints."""
- # Manually trigger joint state update
- driver._update_joint_state()
-
- assert driver.shared_state.joint_positions is not None
- assert len(driver.shared_state.joint_positions) == 6
-
-
-# =============================================================================
-# Servo Control Tests
-# =============================================================================
-
-
-class TestServoControl:
- """Tests for servo enable/disable."""
-
- def test_enable_servo_calls_sdk(self, driver, mock_sdk):
- """enable_servo should call SDK's enable_servos."""
- result = driver.enable_servo()
-
- assert result["success"] is True
- assert mock_sdk.enable_servos_called
-
- def test_enable_servo_updates_shared_state(self, driver):
- """enable_servo should update shared state."""
- driver.enable_servo()
-
- # Trigger state update to sync
- driver._update_robot_state()
-
- assert driver.shared_state.is_enabled is True
-
- def test_disable_servo_calls_sdk(self, driver, mock_sdk):
- """disable_servo should call SDK's disable_servos."""
- driver.enable_servo() # Enable first
- result = driver.disable_servo()
-
- assert result["success"] is True
- assert mock_sdk.disable_servos_called
-
- def test_enable_fails_with_error(self, standard_components):
- """enable_servo should return failure when SDK fails."""
- config = MockSDKConfig(enable_fails=True)
- mock_sdk = MockSDK(config=config)
-
- driver = BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config={"dof": 6},
- name="TestDriver",
- )
-
- result = driver.enable_servo()
-
- assert result["success"] is False
-
- driver.stop()
-
-
-# =============================================================================
-# Motion Control Tests
-# =============================================================================
-
-
-class TestMotionControl:
- """Tests for motion commands."""
-
- def test_move_joint_blocking_calls_sdk(self, driver, mock_sdk):
- """move_joint with wait=True should call SDK directly."""
- # Enable servos first (required for motion)
- driver.enable_servo()
-
- target = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
- # Use wait=True to bypass queue and call SDK directly
- result = driver.move_joint(target, velocity=0.5, wait=True)
-
- assert result["success"] is True
- assert mock_sdk.set_joint_positions_called
-
- # Verify arguments
- call = mock_sdk.get_last_call("set_joint_positions")
- assert call is not None
- assert list(call.args[0]) == target
- assert call.kwargs["velocity"] == 0.5
-
- def test_move_joint_async_queues_command(self, driver, mock_sdk):
- """move_joint with wait=False should queue command."""
- # Enable servos first
- driver.enable_servo()
-
- target = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
- # Default wait=False queues command
- result = driver.move_joint(target, velocity=0.5)
-
- assert result["success"] is True
- assert result.get("queued") is True
- # SDK not called yet (command is in queue)
- assert not mock_sdk.set_joint_positions_called
- # But command is in the queue
- assert not driver.command_queue.empty()
-
- def test_move_joint_fails_without_enable(self, driver, mock_sdk):
- """move_joint should fail if servos not enabled (blocking mode)."""
- target = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
- # Use wait=True to test synchronous failure
- result = driver.move_joint(target, wait=True)
-
- assert result["success"] is False
-
- def test_move_joint_with_simulated_motion(self, standard_components):
- """With simulate_motion, positions should update (blocking mode)."""
- config = MockSDKConfig(simulate_motion=True)
- mock_sdk = MockSDK(config=config)
-
- driver = BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config={"dof": 6},
- name="TestDriver",
- )
-
- driver.enable_servo()
- target = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
- # Use wait=True to execute directly
- driver.move_joint(target, wait=True)
-
- # Check SDK internal state updated
- assert mock_sdk.get_joint_positions() == target
-
- driver.stop()
-
- def test_stop_motion_calls_sdk(self, driver, mock_sdk):
- """stop_motion should call SDK's stop_motion."""
- result = driver.stop_motion()
-
- # stop_motion may return success=False if not moving, but should not error
- assert result is not None
- assert mock_sdk.stop_motion_called
-
- def test_process_command_calls_sdk(self, driver, mock_sdk):
- """_process_command should execute queued commands."""
- from ..driver import Command
-
- driver.enable_servo()
-
- # Create a position command directly
- command = Command(
- type="position",
- data={"positions": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "velocity": 0.5},
- )
-
- # Process it directly
- driver._process_command(command)
-
- assert mock_sdk.set_joint_positions_called
-
-
-# =============================================================================
-# Robot State Tests
-# =============================================================================
-
-
-class TestRobotState:
- """Tests for robot state reading."""
-
- def test_get_robot_state_returns_state(self, driver):
- """get_robot_state should return state info."""
- result = driver.get_robot_state()
-
- assert result["success"] is True
- assert "state" in result
- assert "mode" in result
- assert "error_code" in result
-
- def test_get_robot_state_with_error(self, standard_components):
- """get_robot_state should report errors from SDK."""
- config = MockSDKConfig(error_code=42)
- mock_sdk = MockSDK(config=config)
-
- driver = BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config={"dof": 6},
- name="TestDriver",
- )
-
- result = driver.get_robot_state()
-
- assert result["error_code"] == 42
-
- driver.stop()
-
- def test_clear_errors_calls_sdk(self, driver, mock_sdk):
- """clear_errors should call SDK's clear_errors."""
- result = driver.clear_errors()
-
- assert result["success"] is True
- assert mock_sdk.clear_errors_called
-
-
-# =============================================================================
-# Joint Limits Tests
-# =============================================================================
-
-
-class TestJointLimits:
- """Tests for joint limit queries."""
-
- def test_get_joint_limits_returns_limits(self, driver):
- """get_joint_limits should return lower and upper limits."""
- result = driver.get_joint_limits()
-
- assert result["success"] is True
- assert len(result["lower"]) == 6
- assert len(result["upper"]) == 6
-
- def test_joint_limits_are_reasonable(self, driver):
- """Joint limits should be reasonable values."""
- result = driver.get_joint_limits()
-
- for lower, upper in zip(result["lower"], result["upper"], strict=False):
- assert lower < upper
- assert lower >= -2 * math.pi
- assert upper <= 2 * math.pi
-
-
-# =============================================================================
-# Capabilities Tests
-# =============================================================================
-
-
-class TestCapabilities:
- """Tests for driver capabilities."""
-
- def test_capabilities_from_sdk(self, driver):
- """Driver should get capabilities from SDK."""
- assert driver.capabilities.dof == 6
- assert len(driver.capabilities.max_joint_velocity) == 6
- assert len(driver.capabilities.joint_limits_lower) == 6
-
- def test_capabilities_with_different_dof(self, standard_components):
- """Driver should support different DOF arms."""
- mock_sdk = MockSDK(dof=7)
-
- driver = BaseManipulatorDriver(
- sdk=mock_sdk,
- components=standard_components,
- config={"dof": 7},
- name="TestDriver",
- )
-
- assert driver.capabilities.dof == 7
- assert len(driver.capabilities.max_joint_velocity) == 7
-
- driver.stop()
-
-
-# =============================================================================
-# Component API Exposure Tests
-# =============================================================================
-
-
-class TestComponentAPIExposure:
- """Tests for auto-exposed component APIs."""
-
- def test_motion_component_api_exposed(self, driver):
- """Motion component APIs should be exposed on driver."""
- assert hasattr(driver, "move_joint")
- assert hasattr(driver, "stop_motion")
- assert callable(driver.move_joint)
-
- def test_servo_component_api_exposed(self, driver):
- """Servo component APIs should be exposed on driver."""
- assert hasattr(driver, "enable_servo")
- assert hasattr(driver, "disable_servo")
- assert callable(driver.enable_servo)
-
- def test_status_component_api_exposed(self, driver):
- """Status component APIs should be exposed on driver."""
- assert hasattr(driver, "get_joint_state")
- assert hasattr(driver, "get_robot_state")
- assert hasattr(driver, "get_joint_limits")
- assert callable(driver.get_joint_state)
-
-
-# =============================================================================
-# Threading Tests
-# =============================================================================
-
-
-class TestThreading:
- """Tests for driver threading behavior."""
-
- def test_start_creates_threads(self, driver):
- """start() should create control threads."""
- driver.start()
- time.sleep(0.05)
-
- assert len(driver.threads) >= 2
- assert all(t.is_alive() for t in driver.threads)
-
- driver.stop()
-
- def test_stop_terminates_threads(self, started_driver):
- """stop() should terminate all threads."""
- started_driver.stop()
- time.sleep(0.1)
-
- assert all(not t.is_alive() for t in started_driver.threads)
-
- def test_stop_calls_sdk_stop_motion(self, started_driver, mock_sdk):
- """stop() should call SDK stop_motion."""
- started_driver.stop()
-
- assert mock_sdk.stop_motion_called
-
-
-# =============================================================================
-# Call Verification Tests (MockSDK features)
-# =============================================================================
-
-
-class TestMockSDKCallTracking:
- """Tests for MockSDK call tracking features."""
-
- def test_call_count(self, mock_sdk):
- """MockSDK should count method calls."""
- mock_sdk.get_joint_positions()
- mock_sdk.get_joint_positions()
- mock_sdk.get_joint_positions()
-
- assert mock_sdk.call_count("get_joint_positions") == 3
-
- def test_was_called(self, mock_sdk):
- """MockSDK.was_called should report if method called."""
- assert not mock_sdk.was_called("enable_servos")
-
- mock_sdk.enable_servos()
-
- assert mock_sdk.was_called("enable_servos")
-
- def test_get_last_call_args(self, mock_sdk):
- """MockSDK should record call arguments."""
- positions = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
- mock_sdk.enable_servos()
- mock_sdk.set_joint_positions(positions, velocity=0.5, wait=True)
-
- call = mock_sdk.get_last_call("set_joint_positions")
-
- assert call is not None
- assert list(call.args[0]) == positions
- assert call.kwargs["velocity"] == 0.5
- assert call.kwargs["wait"] is True
-
- def test_reset_calls(self, mock_sdk):
- """MockSDK.reset_calls should clear call history."""
- mock_sdk.enable_servos()
- mock_sdk.get_joint_positions()
-
- mock_sdk.reset_calls()
-
- assert mock_sdk.call_count("enable_servos") == 0
- assert mock_sdk.call_count("get_joint_positions") == 0
- assert not mock_sdk.enable_servos_called
-
-
-# =============================================================================
-# Edge Case Tests
-# =============================================================================
-
-
-class TestEdgeCases:
- """Tests for edge cases and error handling."""
-
- def test_multiple_enable_calls_optimized(self, driver):
- """Multiple enable calls should only call SDK once (optimization)."""
- result1 = driver.enable_servo()
- result2 = driver.enable_servo()
- result3 = driver.enable_servo()
-
- # All calls succeed
- assert result1["success"] is True
- assert result2["success"] is True
- assert result3["success"] is True
-
- # But SDK only called once (component optimizes redundant calls)
- assert driver.sdk.call_count("enable_servos") == 1
-
- # Second and third calls should indicate already enabled
- assert result2.get("message") == "Servos already enabled"
- assert result3.get("message") == "Servos already enabled"
-
- def test_disable_when_already_disabled(self, driver):
- """Disable when already disabled should return success without SDK call."""
- # MockSDK starts with servos disabled
- result = driver.disable_servo()
-
- assert result["success"] is True
- assert result.get("message") == "Servos already disabled"
- # SDK not called since already disabled
- assert not driver.sdk.disable_servos_called
-
- def test_disable_after_enable(self, driver):
- """Disable after enable should call SDK."""
- driver.enable_servo()
- result = driver.disable_servo()
-
- assert result["success"] is True
- assert driver.sdk.disable_servos_called
-
- def test_emergency_stop(self, driver):
- """emergency_stop should disable servos."""
- driver.enable_servo()
-
- driver.sdk.emergency_stop()
-
- assert driver.sdk.emergency_stop_called
- assert not driver.sdk.are_servos_enabled()
diff --git a/dimos/hardware/manipulators/base/utils/__init__.py b/dimos/hardware/manipulators/base/utils/__init__.py
deleted file mode 100644
index a2dcb2f82e..0000000000
--- a/dimos/hardware/manipulators/base/utils/__init__.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Shared utilities for manipulator drivers."""
-
-from .converters import degrees_to_radians, meters_to_mm, mm_to_meters, radians_to_degrees
-from .shared_state import SharedState
-from .validators import (
- clamp_positions,
- scale_velocities,
- validate_acceleration_limits,
- validate_joint_limits,
- validate_trajectory,
- validate_velocity_limits,
-)
-
-__all__ = [
- "SharedState",
- "clamp_positions",
- "degrees_to_radians",
- "meters_to_mm",
- "mm_to_meters",
- "radians_to_degrees",
- "scale_velocities",
- "validate_acceleration_limits",
- "validate_joint_limits",
- "validate_trajectory",
- "validate_velocity_limits",
-]
diff --git a/dimos/hardware/manipulators/base/utils/converters.py b/dimos/hardware/manipulators/base/utils/converters.py
deleted file mode 100644
index dff5956f8e..0000000000
--- a/dimos/hardware/manipulators/base/utils/converters.py
+++ /dev/null
@@ -1,266 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Unit conversion utilities for manipulator drivers."""
-
-import math
-
-
-def degrees_to_radians(degrees: float | list[float]) -> float | list[float]:
- """Convert degrees to radians.
-
- Args:
- degrees: Angle(s) in degrees
-
- Returns:
- Angle(s) in radians
- """
- if isinstance(degrees, list):
- return [math.radians(d) for d in degrees]
- return math.radians(degrees)
-
-
-def radians_to_degrees(radians: float | list[float]) -> float | list[float]:
- """Convert radians to degrees.
-
- Args:
- radians: Angle(s) in radians
-
- Returns:
- Angle(s) in degrees
- """
- if isinstance(radians, list):
- return [math.degrees(r) for r in radians]
- return math.degrees(radians)
-
-
-def mm_to_meters(mm: float | list[float]) -> float | list[float]:
- """Convert millimeters to meters.
-
- Args:
- mm: Distance(s) in millimeters
-
- Returns:
- Distance(s) in meters
- """
- if isinstance(mm, list):
- return [m / 1000.0 for m in mm]
- return mm / 1000.0
-
-
-def meters_to_mm(meters: float | list[float]) -> float | list[float]:
- """Convert meters to millimeters.
-
- Args:
- meters: Distance(s) in meters
-
- Returns:
- Distance(s) in millimeters
- """
- if isinstance(meters, list):
- return [m * 1000.0 for m in meters]
- return meters * 1000.0
-
-
-def rpm_to_rad_per_sec(rpm: float | list[float]) -> float | list[float]:
- """Convert RPM to rad/s.
-
- Args:
- rpm: Angular velocity in RPM
-
- Returns:
- Angular velocity in rad/s
- """
- factor = (2 * math.pi) / 60.0
- if isinstance(rpm, list):
- return [r * factor for r in rpm]
- return rpm * factor
-
-
-def rad_per_sec_to_rpm(rad_per_sec: float | list[float]) -> float | list[float]:
- """Convert rad/s to RPM.
-
- Args:
- rad_per_sec: Angular velocity in rad/s
-
- Returns:
- Angular velocity in RPM
- """
- factor = 60.0 / (2 * math.pi)
- if isinstance(rad_per_sec, list):
- return [r * factor for r in rad_per_sec]
- return rad_per_sec * factor
-
-
-def quaternion_to_euler(qx: float, qy: float, qz: float, qw: float) -> tuple[float, float, float]:
- """Convert quaternion to Euler angles (roll, pitch, yaw).
-
- Args:
- qx, qy, qz, qw: Quaternion components
-
- Returns:
- Tuple of (roll, pitch, yaw) in radians
- """
- # Roll (x-axis rotation)
- sinr_cosp = 2 * (qw * qx + qy * qz)
- cosr_cosp = 1 - 2 * (qx * qx + qy * qy)
- roll = math.atan2(sinr_cosp, cosr_cosp)
-
- # Pitch (y-axis rotation)
- sinp = 2 * (qw * qy - qz * qx)
- if abs(sinp) >= 1:
- pitch = math.copysign(math.pi / 2, sinp) # Use 90 degrees if out of range
- else:
- pitch = math.asin(sinp)
-
- # Yaw (z-axis rotation)
- siny_cosp = 2 * (qw * qz + qx * qy)
- cosy_cosp = 1 - 2 * (qy * qy + qz * qz)
- yaw = math.atan2(siny_cosp, cosy_cosp)
-
- return roll, pitch, yaw
-
-
-def euler_to_quaternion(roll: float, pitch: float, yaw: float) -> tuple[float, float, float, float]:
- """Convert Euler angles to quaternion.
-
- Args:
- roll, pitch, yaw: Euler angles in radians
-
- Returns:
- Tuple of (qx, qy, qz, qw) quaternion components
- """
- cy = math.cos(yaw * 0.5)
- sy = math.sin(yaw * 0.5)
- cp = math.cos(pitch * 0.5)
- sp = math.sin(pitch * 0.5)
- cr = math.cos(roll * 0.5)
- sr = math.sin(roll * 0.5)
-
- qw = cr * cp * cy + sr * sp * sy
- qx = sr * cp * cy - cr * sp * sy
- qy = cr * sp * cy + sr * cp * sy
- qz = cr * cp * sy - sr * sp * cy
-
- return qx, qy, qz, qw
-
-
-def pose_dict_to_list(pose: dict[str, float]) -> list[float]:
- """Convert pose dictionary to list format.
-
- Args:
- pose: Dict with keys: x, y, z, roll, pitch, yaw
-
- Returns:
- List [x, y, z, roll, pitch, yaw]
- """
- return [
- pose.get("x", 0.0),
- pose.get("y", 0.0),
- pose.get("z", 0.0),
- pose.get("roll", 0.0),
- pose.get("pitch", 0.0),
- pose.get("yaw", 0.0),
- ]
-
-
-def pose_list_to_dict(pose: list[float]) -> dict[str, float]:
- """Convert pose list to dictionary format.
-
- Args:
- pose: List [x, y, z, roll, pitch, yaw]
-
- Returns:
- Dict with keys: x, y, z, roll, pitch, yaw
- """
- if len(pose) < 6:
- raise ValueError(f"Pose list must have 6 elements, got {len(pose)}")
-
- return {
- "x": pose[0],
- "y": pose[1],
- "z": pose[2],
- "roll": pose[3],
- "pitch": pose[4],
- "yaw": pose[5],
- }
-
-
-def twist_dict_to_list(twist: dict[str, float]) -> list[float]:
- """Convert twist dictionary to list format.
-
- Args:
- twist: Dict with keys: vx, vy, vz, wx, wy, wz
-
- Returns:
- List [vx, vy, vz, wx, wy, wz]
- """
- return [
- twist.get("vx", 0.0),
- twist.get("vy", 0.0),
- twist.get("vz", 0.0),
- twist.get("wx", 0.0),
- twist.get("wy", 0.0),
- twist.get("wz", 0.0),
- ]
-
-
-def twist_list_to_dict(twist: list[float]) -> dict[str, float]:
- """Convert twist list to dictionary format.
-
- Args:
- twist: List [vx, vy, vz, wx, wy, wz]
-
- Returns:
- Dict with keys: vx, vy, vz, wx, wy, wz
- """
- if len(twist) < 6:
- raise ValueError(f"Twist list must have 6 elements, got {len(twist)}")
-
- return {
- "vx": twist[0],
- "vy": twist[1],
- "vz": twist[2],
- "wx": twist[3],
- "wy": twist[4],
- "wz": twist[5],
- }
-
-
-def normalize_angle(angle: float) -> float:
- """Normalize angle to [-pi, pi].
-
- Args:
- angle: Angle in radians
-
- Returns:
- Normalized angle in [-pi, pi]
- """
- while angle > math.pi:
- angle -= 2 * math.pi
- while angle < -math.pi:
- angle += 2 * math.pi
- return angle
-
-
-def normalize_angles(angles: list[float]) -> list[float]:
- """Normalize angles to [-pi, pi].
-
- Args:
- angles: Angles in radians
-
- Returns:
- Normalized angles in [-pi, pi]
- """
- return [normalize_angle(a) for a in angles]
diff --git a/dimos/hardware/manipulators/base/utils/shared_state.py b/dimos/hardware/manipulators/base/utils/shared_state.py
deleted file mode 100644
index 8af275ea17..0000000000
--- a/dimos/hardware/manipulators/base/utils/shared_state.py
+++ /dev/null
@@ -1,255 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Thread-safe shared state for manipulator drivers."""
-
-from dataclasses import dataclass, field
-from threading import Lock
-import time
-from typing import Any
-
-
-@dataclass
-class SharedState:
- """Thread-safe shared state for manipulator drivers.
-
- This class holds the current state of the manipulator that needs to be
- shared between multiple threads (state reader, command sender, publisher).
- All access should be protected by the lock.
- """
-
- # Thread synchronization
- lock: Lock = field(default_factory=Lock)
-
- # Joint state (current values from hardware)
- joint_positions: list[float] | None = None # radians
- joint_velocities: list[float] | None = None # rad/s
- joint_efforts: list[float] | None = None # Nm
-
- # Joint targets (commanded values)
- target_positions: list[float] | None = None # radians
- target_velocities: list[float] | None = None # rad/s
- target_efforts: list[float] | None = None # Nm
-
- # Cartesian state (if available)
- cartesian_position: dict[str, float] | None = None # x,y,z,roll,pitch,yaw
- cartesian_velocity: dict[str, float] | None = None # vx,vy,vz,wx,wy,wz
-
- # Cartesian targets
- target_cartesian_position: dict[str, float] | None = None
- target_cartesian_velocity: dict[str, float] | None = None
-
- # Force/torque sensor (if available)
- force_torque: list[float] | None = None # [fx,fy,fz,tx,ty,tz]
-
- # System state
- robot_state: int = 0 # 0=idle, 1=moving, 2=error, 3=e-stop
- control_mode: int = 0 # 0=position, 1=velocity, 2=torque
- error_code: int = 0 # 0 = no error
- error_message: str = "" # Human-readable error
-
- # Connection and enable status
- is_connected: bool = False
- is_enabled: bool = False
- is_moving: bool = False
- is_homed: bool = False
-
- # Gripper state (if available)
- gripper_position: float | None = None # meters
- gripper_force: float | None = None # Newtons
-
- # Timestamps
- last_state_update: float = 0.0
- last_command_sent: float = 0.0
- last_error_time: float = 0.0
-
- # Statistics
- state_read_count: int = 0
- command_sent_count: int = 0
- error_count: int = 0
-
- def update_joint_state(
- self,
- positions: list[float] | None = None,
- velocities: list[float] | None = None,
- efforts: list[float] | None = None,
- ) -> None:
- """Thread-safe update of joint state.
-
- Args:
- positions: Joint positions in radians
- velocities: Joint velocities in rad/s
- efforts: Joint efforts in Nm
- """
- with self.lock:
- if positions is not None:
- self.joint_positions = positions
- if velocities is not None:
- self.joint_velocities = velocities
- if efforts is not None:
- self.joint_efforts = efforts
- self.last_state_update = time.time()
- self.state_read_count += 1
-
- def update_robot_state(
- self,
- state: int | None = None,
- mode: int | None = None,
- error_code: int | None = None,
- error_message: str | None = None,
- ) -> None:
- """Thread-safe update of robot state.
-
- Args:
- state: Robot state code
- mode: Control mode code
- error_code: Error code (0 = no error)
- error_message: Human-readable error message
- """
- with self.lock:
- if state is not None:
- self.robot_state = state
- if mode is not None:
- self.control_mode = mode
- if error_code is not None:
- self.error_code = error_code
- if error_code != 0:
- self.error_count += 1
- self.last_error_time = time.time()
- if error_message is not None:
- self.error_message = error_message
-
- def update_cartesian_state(
- self, position: dict[str, float] | None = None, velocity: dict[str, float] | None = None
- ) -> None:
- """Thread-safe update of Cartesian state.
-
- Args:
- position: End-effector pose (x,y,z,roll,pitch,yaw)
- velocity: End-effector twist (vx,vy,vz,wx,wy,wz)
- """
- with self.lock:
- if position is not None:
- self.cartesian_position = position
- if velocity is not None:
- self.cartesian_velocity = velocity
-
- def set_target_joints(
- self,
- positions: list[float] | None = None,
- velocities: list[float] | None = None,
- efforts: list[float] | None = None,
- ) -> None:
- """Thread-safe update of joint targets.
-
- Args:
- positions: Target positions in radians
- velocities: Target velocities in rad/s
- efforts: Target efforts in Nm
- """
- with self.lock:
- if positions is not None:
- self.target_positions = positions
- if velocities is not None:
- self.target_velocities = velocities
- if efforts is not None:
- self.target_efforts = efforts
- self.last_command_sent = time.time()
- self.command_sent_count += 1
-
- def get_joint_state(
- self,
- ) -> tuple[list[float] | None, list[float] | None, list[float] | None]:
- """Thread-safe read of joint state.
-
- Returns:
- Tuple of (positions, velocities, efforts)
- """
- with self.lock:
- return (
- self.joint_positions.copy() if self.joint_positions else None,
- self.joint_velocities.copy() if self.joint_velocities else None,
- self.joint_efforts.copy() if self.joint_efforts else None,
- )
-
- def get_robot_state(self) -> dict[str, Any]:
- """Thread-safe read of robot state.
-
- Returns:
- Dict with state information
- """
- with self.lock:
- return {
- "state": self.robot_state,
- "mode": self.control_mode,
- "error_code": self.error_code,
- "error_message": self.error_message,
- "is_connected": self.is_connected,
- "is_enabled": self.is_enabled,
- "is_moving": self.is_moving,
- "last_update": self.last_state_update,
- }
-
- def get_statistics(self) -> dict[str, Any]:
- """Get statistics about state updates.
-
- Returns:
- Dict with statistics
- """
- with self.lock:
- return {
- "state_read_count": self.state_read_count,
- "command_sent_count": self.command_sent_count,
- "error_count": self.error_count,
- "last_state_update": self.last_state_update,
- "last_command_sent": self.last_command_sent,
- "last_error_time": self.last_error_time,
- }
-
- def clear_errors(self) -> None:
- """Clear error state."""
- with self.lock:
- self.error_code = 0
- self.error_message = ""
-
- def reset(self) -> None:
- """Reset all state to initial values."""
- with self.lock:
- self.joint_positions = None
- self.joint_velocities = None
- self.joint_efforts = None
- self.target_positions = None
- self.target_velocities = None
- self.target_efforts = None
- self.cartesian_position = None
- self.cartesian_velocity = None
- self.target_cartesian_position = None
- self.target_cartesian_velocity = None
- self.force_torque = None
- self.robot_state = 0
- self.control_mode = 0
- self.error_code = 0
- self.error_message = ""
- self.is_connected = False
- self.is_enabled = False
- self.is_moving = False
- self.is_homed = False
- self.gripper_position = None
- self.gripper_force = None
- self.last_state_update = 0.0
- self.last_command_sent = 0.0
- self.last_error_time = 0.0
- self.state_read_count = 0
- self.command_sent_count = 0
- self.error_count = 0
diff --git a/dimos/hardware/manipulators/base/utils/validators.py b/dimos/hardware/manipulators/base/utils/validators.py
deleted file mode 100644
index 3fabdcd306..0000000000
--- a/dimos/hardware/manipulators/base/utils/validators.py
+++ /dev/null
@@ -1,254 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Validation utilities for manipulator drivers."""
-
-from typing import cast
-
-
-def validate_joint_limits(
- positions: list[float],
- lower_limits: list[float],
- upper_limits: list[float],
- tolerance: float = 0.0,
-) -> tuple[bool, str | None]:
- """Validate joint positions are within limits.
-
- Args:
- positions: Joint positions to validate (radians)
- lower_limits: Lower joint limits (radians)
- upper_limits: Upper joint limits (radians)
- tolerance: Optional tolerance for soft limits (radians)
-
- Returns:
- Tuple of (is_valid, error_message)
- If valid, error_message is None
- """
- if len(positions) != len(lower_limits) or len(positions) != len(upper_limits):
- return False, f"Dimension mismatch: {len(positions)} positions, {len(lower_limits)} limits"
-
- for i, pos in enumerate(positions):
- lower = lower_limits[i] - tolerance
- upper = upper_limits[i] + tolerance
-
- if pos < lower:
- return False, f"Joint {i} position {pos:.3f} below limit {lower_limits[i]:.3f}"
-
- if pos > upper:
- return False, f"Joint {i} position {pos:.3f} above limit {upper_limits[i]:.3f}"
-
- return True, None
-
-
-def validate_velocity_limits(
- velocities: list[float], max_velocities: list[float], scale_factor: float = 1.0
-) -> tuple[bool, str | None]:
- """Validate joint velocities are within limits.
-
- Args:
- velocities: Joint velocities to validate (rad/s)
- max_velocities: Maximum allowed velocities (rad/s)
- scale_factor: Optional scaling factor (0-1) to reduce max velocity
-
- Returns:
- Tuple of (is_valid, error_message)
- If valid, error_message is None
- """
- if len(velocities) != len(max_velocities):
- return (
- False,
- f"Dimension mismatch: {len(velocities)} velocities, {len(max_velocities)} limits",
- )
-
- if scale_factor <= 0 or scale_factor > 1:
- return False, f"Invalid scale factor: {scale_factor} (must be in (0, 1])"
-
- for i, vel in enumerate(velocities):
- max_vel = max_velocities[i] * scale_factor
-
- if abs(vel) > max_vel:
- return False, f"Joint {i} velocity {abs(vel):.3f} exceeds limit {max_vel:.3f}"
-
- return True, None
-
-
-def validate_acceleration_limits(
- accelerations: list[float], max_accelerations: list[float], scale_factor: float = 1.0
-) -> tuple[bool, str | None]:
- """Validate joint accelerations are within limits.
-
- Args:
- accelerations: Joint accelerations to validate (rad/s²)
- max_accelerations: Maximum allowed accelerations (rad/s²)
- scale_factor: Optional scaling factor (0-1) to reduce max acceleration
-
- Returns:
- Tuple of (is_valid, error_message)
- If valid, error_message is None
- """
- if len(accelerations) != len(max_accelerations):
- return (
- False,
- f"Dimension mismatch: {len(accelerations)} accelerations, {len(max_accelerations)} limits",
- )
-
- if scale_factor <= 0 or scale_factor > 1:
- return False, f"Invalid scale factor: {scale_factor} (must be in (0, 1])"
-
- for i, acc in enumerate(accelerations):
- max_acc = max_accelerations[i] * scale_factor
-
- if abs(acc) > max_acc:
- return False, f"Joint {i} acceleration {abs(acc):.3f} exceeds limit {max_acc:.3f}"
-
- return True, None
-
-
-def validate_trajectory(
- trajectory: list[dict[str, float | list[float]]],
- lower_limits: list[float],
- upper_limits: list[float],
- max_velocities: list[float] | None = None,
- max_accelerations: list[float] | None = None,
-) -> tuple[bool, str | None]:
- """Validate a joint trajectory.
-
- Args:
- trajectory: List of waypoints, each with:
- - 'positions': list[float] in radians
- - 'velocities': Optional list[float] in rad/s
- - 'time': float seconds from start
- lower_limits: Lower joint limits (radians)
- upper_limits: Upper joint limits (radians)
- max_velocities: Optional maximum velocities (rad/s)
- max_accelerations: Optional maximum accelerations (rad/s²)
-
- Returns:
- Tuple of (is_valid, error_message)
- If valid, error_message is None
- """
- if not trajectory:
- return False, "Empty trajectory"
-
- # Check first waypoint starts at time 0
- if trajectory[0].get("time", 0) != 0:
- return False, "Trajectory must start at time 0"
-
- # Check waypoints are time-ordered
- prev_time: float = -1.0
- for i, waypoint in enumerate(trajectory):
- curr_time = cast("float", waypoint.get("time", 0))
- if curr_time <= prev_time:
- return False, f"Waypoint {i} time {curr_time} not after previous {prev_time}"
- prev_time = curr_time
-
- # Validate each waypoint
- for i, waypoint in enumerate(trajectory):
- # Check required fields
- if "positions" not in waypoint:
- return False, f"Waypoint {i} missing positions"
-
- positions = cast("list[float]", waypoint["positions"])
-
- # Validate position limits
- valid, error = validate_joint_limits(positions, lower_limits, upper_limits)
- if not valid:
- return False, f"Waypoint {i}: {error}"
-
- # Validate velocity limits if provided
- if "velocities" in waypoint and max_velocities:
- velocities = cast("list[float]", waypoint["velocities"])
- valid, error = validate_velocity_limits(velocities, max_velocities)
- if not valid:
- return False, f"Waypoint {i}: {error}"
-
- # Check acceleration limits between waypoints
- if max_accelerations and len(trajectory) > 1:
- for i in range(1, len(trajectory)):
- prev = trajectory[i - 1]
- curr = trajectory[i]
-
- dt = cast("float", curr["time"]) - cast("float", prev["time"])
- if dt <= 0:
- continue
-
- # Estimate acceleration from position change
- prev_pos = cast("list[float]", prev["positions"])
- curr_pos = cast("list[float]", curr["positions"])
- for j in range(len(prev_pos)):
- pos_change = curr_pos[j] - prev_pos[j]
- pos_change / dt
-
- # If velocities provided, use them for better estimate
- if "velocities" in prev and "velocities" in curr:
- prev_vel = cast("list[float]", prev["velocities"])
- curr_vel = cast("list[float]", curr["velocities"])
- vel_change = curr_vel[j] - prev_vel[j]
- acc = vel_change / dt
- if abs(acc) > max_accelerations[j]:
- return (
- False,
- f"Acceleration between waypoint {i - 1} and {i} joint {j}: {abs(acc):.3f} exceeds limit {max_accelerations[j]:.3f}",
- )
-
- return True, None
-
-
-def scale_velocities(
- velocities: list[float], max_velocities: list[float], scale_factor: float = 0.8
-) -> list[float]:
- """Scale velocities to stay within limits.
-
- Args:
- velocities: Desired velocities (rad/s)
- max_velocities: Maximum allowed velocities (rad/s)
- scale_factor: Safety factor (0-1) to stay below limits
-
- Returns:
- Scaled velocities that respect limits
- """
- if not velocities or not max_velocities:
- return velocities
-
- # Find the joint that requires most scaling
- max_scale = 1.0
- for vel, max_vel in zip(velocities, max_velocities, strict=False):
- if max_vel > 0 and abs(vel) > 0:
- required_scale = abs(vel) / (max_vel * scale_factor)
- max_scale = max(max_scale, required_scale)
-
- # Apply uniform scaling to maintain direction
- if max_scale > 1.0:
- return [v / max_scale for v in velocities]
-
- return velocities
-
-
-def clamp_positions(
- positions: list[float], lower_limits: list[float], upper_limits: list[float]
-) -> list[float]:
- """Clamp positions to stay within limits.
-
- Args:
- positions: Desired positions (radians)
- lower_limits: Lower joint limits (radians)
- upper_limits: Upper joint limits (radians)
-
- Returns:
- Clamped positions within limits
- """
- clamped = []
- for pos, lower, upper in zip(positions, lower_limits, upper_limits, strict=False):
- clamped.append(max(lower, min(upper, pos)))
- return clamped
diff --git a/dimos/hardware/manipulators/mock/__init__.py b/dimos/hardware/manipulators/mock/__init__.py
new file mode 100644
index 0000000000..87428973a4
--- /dev/null
+++ b/dimos/hardware/manipulators/mock/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mock backend for testing manipulator drivers without hardware.
+
+Usage:
+ >>> from dimos.hardware.manipulators.xarm import XArm
+ >>> from dimos.hardware.manipulators.mock import MockBackend
+ >>> arm = XArm(backend=MockBackend())
+ >>> arm.start() # No hardware needed!
+ >>> arm.move_joint([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
+ >>> assert arm.backend.read_joint_positions() == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+"""
+
+from dimos.hardware.manipulators.mock.backend import MockBackend
+
+__all__ = ["MockBackend"]
diff --git a/dimos/hardware/manipulators/mock/backend.py b/dimos/hardware/manipulators/mock/backend.py
new file mode 100644
index 0000000000..80b3543739
--- /dev/null
+++ b/dimos/hardware/manipulators/mock/backend.py
@@ -0,0 +1,250 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mock backend for testing - no hardware required.
+
+Usage:
+ >>> from dimos.hardware.manipulators.xarm import XArm
+ >>> from dimos.hardware.manipulators.mock import MockBackend
+ >>> arm = XArm(backend=MockBackend())
+ >>> arm.start() # No hardware!
+"""
+
+import math
+
+from dimos.hardware.manipulators.spec import (
+ ControlMode,
+ JointLimits,
+ ManipulatorInfo,
+)
+
+
+class MockBackend:
+ """Fake backend for unit tests.
+
+ Implements ManipulatorBackend protocol with in-memory state.
+ Useful for:
+ - Unit testing driver logic without hardware
+ - Integration testing with predictable behavior
+ - Development without physical robot
+ """
+
+ def __init__(self, dof: int = 6) -> None:
+ self._dof = dof
+ self._positions = [0.0] * dof
+ self._velocities = [0.0] * dof
+ self._efforts = [0.0] * dof
+ self._enabled = False
+ self._connected = False
+ self._control_mode = ControlMode.POSITION
+ self._cartesian_position: dict[str, float] = {
+ "x": 0.3,
+ "y": 0.0,
+ "z": 0.3,
+ "roll": 0.0,
+ "pitch": 0.0,
+ "yaw": 0.0,
+ }
+ self._gripper_position: float = 0.0
+ self._error_code: int = 0
+ self._error_message: str = ""
+
+ # =========================================================================
+ # Connection
+ # =========================================================================
+
+ def connect(self) -> bool:
+ """Simulate connection."""
+ self._connected = True
+ return True
+
+ def disconnect(self) -> None:
+ """Simulate disconnection."""
+ self._connected = False
+
+ def is_connected(self) -> bool:
+ """Check mock connection status."""
+ return self._connected
+
+ # =========================================================================
+ # Info
+ # =========================================================================
+
+ def get_info(self) -> ManipulatorInfo:
+ """Return mock info."""
+ return ManipulatorInfo(
+ vendor="Mock",
+ model="MockArm",
+ dof=self._dof,
+ firmware_version="1.0.0",
+ serial_number="MOCK-001",
+ )
+
+ def get_dof(self) -> int:
+ """Return DOF."""
+ return self._dof
+
+ def get_limits(self) -> JointLimits:
+ """Return mock joint limits."""
+ return JointLimits(
+ position_lower=[-math.pi] * self._dof,
+ position_upper=[math.pi] * self._dof,
+ velocity_max=[1.0] * self._dof,
+ )
+
+ # =========================================================================
+ # Control Mode
+ # =========================================================================
+
+ def set_control_mode(self, mode: ControlMode) -> bool:
+ """Set mock control mode."""
+ self._control_mode = mode
+ return True
+
+ def get_control_mode(self) -> ControlMode:
+ """Get mock control mode."""
+ return self._control_mode
+
+ # =========================================================================
+ # State Reading
+ # =========================================================================
+
+ def read_joint_positions(self) -> list[float]:
+ """Return mock joint positions."""
+ return self._positions.copy()
+
+ def read_joint_velocities(self) -> list[float]:
+ """Return mock joint velocities."""
+ return self._velocities.copy()
+
+ def read_joint_efforts(self) -> list[float]:
+ """Return mock joint efforts."""
+ return self._efforts.copy()
+
+ def read_state(self) -> dict[str, int]:
+ """Return mock state."""
+ # Use index of control mode as int (0=position, 1=velocity, etc.)
+ mode_int = list(ControlMode).index(self._control_mode)
+ return {
+ "state": 0 if self._enabled else 1,
+ "mode": mode_int,
+ }
+
+ def read_error(self) -> tuple[int, str]:
+ """Return mock error."""
+ return self._error_code, self._error_message
+
+ # =========================================================================
+ # Motion Control
+ # =========================================================================
+
+ def write_joint_positions(
+ self,
+ positions: list[float],
+ velocity: float = 1.0,
+ ) -> bool:
+ """Set mock joint positions (instant move)."""
+ if len(positions) != self._dof:
+ return False
+ self._positions = list(positions)
+ return True
+
+ def write_joint_velocities(self, velocities: list[float]) -> bool:
+ """Set mock joint velocities."""
+ if len(velocities) != self._dof:
+ return False
+ self._velocities = list(velocities)
+ return True
+
+ def write_stop(self) -> bool:
+ """Stop mock motion."""
+ self._velocities = [0.0] * self._dof
+ return True
+
+ # =========================================================================
+ # Servo Control
+ # =========================================================================
+
+ def write_enable(self, enable: bool) -> bool:
+ """Enable/disable mock servos."""
+ self._enabled = enable
+ return True
+
+ def read_enabled(self) -> bool:
+ """Check mock servo state."""
+ return self._enabled
+
+ def write_clear_errors(self) -> bool:
+ """Clear mock errors."""
+ self._error_code = 0
+ self._error_message = ""
+ return True
+
+ # =========================================================================
+ # Cartesian Control (Optional)
+ # =========================================================================
+
+ def read_cartesian_position(self) -> dict[str, float] | None:
+ """Return mock cartesian position."""
+ return self._cartesian_position.copy()
+
+ def write_cartesian_position(
+ self,
+ pose: dict[str, float],
+ velocity: float = 1.0,
+ ) -> bool:
+ """Set mock cartesian position."""
+ self._cartesian_position.update(pose)
+ return True
+
+ # =========================================================================
+ # Gripper (Optional)
+ # =========================================================================
+
+ def read_gripper_position(self) -> float | None:
+ """Return mock gripper position."""
+ return self._gripper_position
+
+ def write_gripper_position(self, position: float) -> bool:
+ """Set mock gripper position."""
+ self._gripper_position = position
+ return True
+
+ # =========================================================================
+ # Force/Torque (Optional)
+ # =========================================================================
+
+ def read_force_torque(self) -> list[float] | None:
+ """Return mock F/T sensor data (not supported in mock)."""
+ return None
+
+ # =========================================================================
+ # Test Helpers (not part of Protocol)
+ # =========================================================================
+
+ def set_error(self, code: int, message: str) -> None:
+ """Inject an error for testing error handling."""
+ self._error_code = code
+ self._error_message = message
+
+ def set_positions(self, positions: list[float]) -> None:
+ """Set positions directly for testing."""
+ self._positions = list(positions)
+
+ def set_efforts(self, efforts: list[float]) -> None:
+ """Set efforts directly for testing."""
+ self._efforts = list(efforts)
+
+
+__all__ = ["MockBackend"]
diff --git a/dimos/hardware/manipulators/piper/README.md b/dimos/hardware/manipulators/piper/README.md
deleted file mode 100644
index 89ff2161ac..0000000000
--- a/dimos/hardware/manipulators/piper/README.md
+++ /dev/null
@@ -1,35 +0,0 @@
-# Piper Driver
-
-Driver for the Piper 6-DOF manipulator with CAN bus communication.
-
-## Supported Features
-
-✅ **Joint Control**
-- Position control
-- Velocity control (integration-based)
-- Joint state feedback at 100Hz
-
-✅ **System Control**
-- Enable/disable motors
-- Emergency stop
-- Error recovery
-
-✅ **Gripper Control**
-- Position and force control
-- Gripper state feedback
-
-## Cartesian Control Limitation
-
-⚠️ **Cartesian control is currently NOT available for the Piper arm.**
-
-### Why?
-The Piper SDK doesn't expose an inverse kinematics (IK) solver that can be called without moving the robot. While the robot can execute Cartesian commands internally, we cannot:
-- Pre-compute joint trajectories for Cartesian paths
-- Validate if a pose is reachable without trying to move there
-- Plan complex Cartesian trajectories offline
-
-### Future Solution
-We will implement a universal IK solver that sits outside the driver layer and works with all arms (XArm, Piper, and future robots), regardless of whether they expose internal IK.
-
-### Current Workaround
-Use joint-space control for now. If you need Cartesian planning, consider using external IK libraries like ikpy or robotics-toolbox-python with the Piper's URDF file.
diff --git a/dimos/hardware/manipulators/piper/__init__.py b/dimos/hardware/manipulators/piper/__init__.py
index acead9f7fb..16c6e451cd 100644
--- a/dimos/hardware/manipulators/piper/__init__.py
+++ b/dimos/hardware/manipulators/piper/__init__.py
@@ -12,21 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Piper Arm Driver
+"""Piper manipulator hardware backend.
-Real-time driver for Piper manipulator with CAN bus communication.
+Usage:
+ >>> from dimos.hardware.manipulators.piper import PiperBackend
+ >>> backend = PiperBackend(can_port="can0")
+ >>> backend.connect()
+ >>> positions = backend.read_joint_positions()
"""
-from .piper_blueprints import piper_cartesian, piper_servo, piper_trajectory
-from .piper_driver import PiperDriver, piper_driver
-from .piper_wrapper import PiperSDKWrapper
+from dimos.hardware.manipulators.piper.backend import PiperBackend
-__all__ = [
- "PiperDriver",
- "PiperSDKWrapper",
- "piper_cartesian",
- "piper_driver",
- "piper_servo",
- "piper_trajectory",
-]
+__all__ = ["PiperBackend"]
diff --git a/dimos/hardware/manipulators/piper/backend.py b/dimos/hardware/manipulators/piper/backend.py
new file mode 100644
index 0000000000..1ce91dccd1
--- /dev/null
+++ b/dimos/hardware/manipulators/piper/backend.py
@@ -0,0 +1,505 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Piper backend - implements ManipulatorBackend protocol.
+
+Handles all Piper SDK communication and unit conversion.
+"""
+
+import math
+import time
+from typing import Any
+
+from dimos.hardware.manipulators.spec import (
+ ControlMode,
+ JointLimits,
+ ManipulatorBackend,
+ ManipulatorInfo,
+)
+
+# Unit conversion constants
+# Piper uses 0.001 degrees internally
+RAD_TO_PIPER = 57295.7795 # radians to Piper units (0.001 degrees)
+PIPER_TO_RAD = 1.0 / RAD_TO_PIPER # Piper units to radians
+
+
+class PiperBackend(ManipulatorBackend):
+ """Piper-specific backend.
+
+ Implements ManipulatorBackend protocol via duck typing.
+ No inheritance required - just matching method signatures.
+
+ Unit conversions:
+ - Angles: Piper uses 0.001 degrees, we use radians
+ - Velocities: Piper uses internal units, we use rad/s
+ """
+
+ def __init__(self, can_port: str = "can0", dof: int = 6) -> None:
+ if dof != 6:
+ raise ValueError(f"PiperBackend only supports 6 DOF (got {dof})")
+ self._can_port = can_port
+ self._dof = dof
+ self._sdk: Any = None
+ self._connected: bool = False
+ self._enabled: bool = False
+ self._control_mode: ControlMode = ControlMode.POSITION
+
+ # =========================================================================
+ # Connection
+ # =========================================================================
+
+ def connect(self) -> bool:
+ """Connect to Piper via CAN bus."""
+ try:
+ from piper_sdk import C_PiperInterface_V2
+
+ self._sdk = C_PiperInterface_V2(
+ can_name=self._can_port,
+ judge_flag=True, # Enable safety checks
+ can_auto_init=True, # Let SDK handle CAN initialization
+ dh_is_offset=False,
+ )
+
+ # Connect to CAN port
+ self._sdk.ConnectPort(piper_init=True, start_thread=True)
+
+ # Wait for initialization
+ time.sleep(0.025)
+
+ # Check connection by trying to get status
+ status = self._sdk.GetArmStatus()
+ if status is not None:
+ self._connected = True
+ print(f"Piper connected via CAN port {self._can_port}")
+ return True
+ else:
+ print(f"ERROR: Failed to connect to Piper on {self._can_port} - no status received")
+ return False
+
+ except ImportError:
+ print("ERROR: Piper SDK not installed. Please install piper_sdk")
+ return False
+ except Exception as e:
+ print(f"ERROR: Failed to connect to Piper on {self._can_port}: {e}")
+ return False
+
+ def disconnect(self) -> None:
+ """Disconnect from Piper."""
+ if self._sdk:
+ try:
+ if self._enabled:
+ self._sdk.DisablePiper()
+ self._enabled = False
+ self._sdk.DisconnectPort()
+ except Exception:
+ pass
+ finally:
+ self._sdk = None
+ self._connected = False
+
+ def is_connected(self) -> bool:
+ """Check if connected to Piper."""
+ if not self._connected or not self._sdk:
+ return False
+
+ try:
+ status = self._sdk.GetArmStatus()
+ return status is not None
+ except Exception:
+ return False
+
+ # =========================================================================
+ # Info
+ # =========================================================================
+
+ def get_info(self) -> ManipulatorInfo:
+ """Get Piper information."""
+ firmware_version = None
+ if self._sdk:
+ try:
+ firmware_version = self._sdk.GetPiperFirmwareVersion()
+ except Exception:
+ pass
+
+ return ManipulatorInfo(
+ vendor="Agilex",
+ model="Piper",
+ dof=self._dof,
+ firmware_version=firmware_version,
+ )
+
+ def get_dof(self) -> int:
+ """Get degrees of freedom."""
+ return self._dof
+
+ def get_limits(self) -> JointLimits:
+ """Get joint limits."""
+ # Piper joint limits (approximate, in radians)
+ lower = [-3.14, -2.35, -2.35, -3.14, -2.35, -3.14]
+ upper = [3.14, 2.35, 2.35, 3.14, 2.35, 3.14]
+ max_vel = [math.pi] * self._dof # ~180 deg/s
+
+ return JointLimits(
+ position_lower=lower,
+ position_upper=upper,
+ velocity_max=max_vel,
+ )
+
+ # =========================================================================
+ # Control Mode
+ # =========================================================================
+
+ def set_control_mode(self, mode: ControlMode) -> bool:
+ """Set Piper control mode via MotionCtrl_2."""
+ if not self._sdk:
+ return False
+
+ # Piper move modes: 0x01=position, 0x02=velocity
+ # SERVO_POSITION uses position mode for high-freq streaming
+ move_mode = 0x01 # Default position mode
+ if mode == ControlMode.VELOCITY:
+ move_mode = 0x02
+
+ try:
+ self._sdk.MotionCtrl_2(
+ ctrl_mode=0x01, # CAN control mode
+ move_mode=move_mode,
+ move_spd_rate_ctrl=50, # Speed rate (0-100)
+ is_mit_mode=0x00, # Not MIT mode
+ )
+ self._control_mode = mode
+ return True
+ except Exception:
+ return False
+
+ def get_control_mode(self) -> ControlMode:
+ """Get current control mode."""
+ return self._control_mode
+
+ # =========================================================================
+ # State Reading
+ # =========================================================================
+
+ def read_joint_positions(self) -> list[float]:
+ """Read joint positions (Piper units -> radians)."""
+ if not self._sdk:
+ raise RuntimeError("Not connected")
+
+ joint_msgs = self._sdk.GetArmJointMsgs()
+ if not joint_msgs or not joint_msgs.joint_state:
+ raise RuntimeError("Failed to read joint positions")
+
+ js = joint_msgs.joint_state
+ return [
+ js.joint_1 * PIPER_TO_RAD,
+ js.joint_2 * PIPER_TO_RAD,
+ js.joint_3 * PIPER_TO_RAD,
+ js.joint_4 * PIPER_TO_RAD,
+ js.joint_5 * PIPER_TO_RAD,
+ js.joint_6 * PIPER_TO_RAD,
+ ]
+
+ def read_joint_velocities(self) -> list[float]:
+ """Read joint velocities.
+
+ Note: Piper doesn't provide real-time velocity feedback.
+ Returns zeros. For velocity estimation, use finite differences.
+ """
+ return [0.0] * self._dof
+
+ def read_joint_efforts(self) -> list[float]:
+ """Read joint efforts/torques.
+
+ Note: Piper doesn't provide torque feedback by default.
+ """
+ return [0.0] * self._dof
+
+ def read_state(self) -> dict[str, int]:
+ """Read robot state."""
+ if not self._sdk:
+ return {"state": 0, "mode": 0}
+
+ try:
+ status = self._sdk.GetArmStatus()
+ if status and status.arm_status:
+ arm_status = status.arm_status
+ error_code = getattr(arm_status, "err_code", 0)
+ state = 2 if error_code != 0 else 0 # 2=error, 0=idle
+ return {
+ "state": state,
+ "mode": 0, # Piper doesn't expose mode
+ "error_code": error_code,
+ }
+ except Exception:
+ pass
+
+ return {"state": 0, "mode": 0}
+
+ def read_error(self) -> tuple[int, str]:
+ """Read error code and message."""
+ if not self._sdk:
+ return 0, ""
+
+ try:
+ status = self._sdk.GetArmStatus()
+ if status and status.arm_status:
+ error_code = getattr(status.arm_status, "err_code", 0)
+ if error_code == 0:
+ return 0, ""
+
+ # Piper error codes
+ error_map = {
+ 1: "Communication error",
+ 2: "Motor error",
+ 3: "Encoder error",
+ 4: "Overtemperature",
+ 5: "Overcurrent",
+ 6: "Joint limit error",
+ 7: "Emergency stop",
+ 8: "Power error",
+ }
+ return error_code, error_map.get(error_code, f"Unknown error {error_code}")
+ except Exception:
+ pass
+
+ return 0, ""
+
+ # =========================================================================
+ # Motion Control (Joint Space)
+ # =========================================================================
+
+ def write_joint_positions(
+ self,
+ positions: list[float],
+ velocity: float = 1.0,
+ ) -> bool:
+ """Write joint positions (radians -> Piper units).
+
+ Args:
+ positions: Target positions in radians
+ velocity: Speed as fraction of max (0-1)
+ """
+ if not self._sdk:
+ return False
+
+ # Convert radians to Piper units (0.001 degrees)
+ piper_joints = [round(rad * RAD_TO_PIPER) for rad in positions]
+
+ # Set speed rate if not full speed
+ if velocity < 1.0:
+ speed_rate = int(velocity * 100)
+ try:
+ self._sdk.MotionCtrl_2(
+ ctrl_mode=0x01,
+ move_mode=0x01,
+ move_spd_rate_ctrl=speed_rate,
+ is_mit_mode=0x00,
+ )
+ except Exception:
+ pass
+
+ try:
+ self._sdk.JointCtrl(
+ piper_joints[0],
+ piper_joints[1],
+ piper_joints[2],
+ piper_joints[3],
+ piper_joints[4],
+ piper_joints[5],
+ )
+ return True
+ except Exception as e:
+ print(f"Piper joint control error: {e}")
+ return False
+
+ def write_joint_velocities(self, velocities: list[float]) -> bool:
+ """Write joint velocities.
+
+ Note: Piper doesn't have native velocity control at SDK level.
+ Returns False - the driver should implement this via position integration.
+ """
+ return False
+
+ def write_stop(self) -> bool:
+ """Emergency stop."""
+ if not self._sdk:
+ return False
+
+ try:
+ if hasattr(self._sdk, "EmergencyStop"):
+ self._sdk.EmergencyStop()
+ return True
+ except Exception:
+ pass
+
+ # Fallback: disable arm
+ return self.write_enable(False)
+
+ # =========================================================================
+ # Servo Control
+ # =========================================================================
+
+ def write_enable(self, enable: bool) -> bool:
+ """Enable or disable servos."""
+ if not self._sdk:
+ return False
+
+ try:
+ if enable:
+ # Enable with retries (500ms max)
+ attempts = 0
+ max_attempts = 50
+ success = False
+ while attempts < max_attempts:
+ if self._sdk.EnablePiper():
+ success = True
+ break
+ time.sleep(0.01)
+ attempts += 1
+
+ if success:
+ self._enabled = True
+ # Set control mode
+ self._sdk.MotionCtrl_2(
+ ctrl_mode=0x01,
+ move_mode=0x01,
+ move_spd_rate_ctrl=30,
+ is_mit_mode=0x00,
+ )
+ return True
+ return False
+ else:
+ self._sdk.DisablePiper()
+ self._enabled = False
+ return True
+ except Exception:
+ return False
+
+ def read_enabled(self) -> bool:
+ """Check if servos are enabled."""
+ return self._enabled
+
+ def write_clear_errors(self) -> bool:
+ """Clear error state."""
+ if not self._sdk:
+ return False
+
+ try:
+ if hasattr(self._sdk, "ClearError"):
+ self._sdk.ClearError()
+ return True
+ except Exception:
+ pass
+
+ # Alternative: disable and re-enable
+ self.write_enable(False)
+ time.sleep(0.1)
+ return self.write_enable(True)
+
+ # =========================================================================
+ # Cartesian Control (Optional)
+ # =========================================================================
+
+ def read_cartesian_position(self) -> dict[str, float] | None:
+ """Read end-effector pose.
+
+ Note: Piper may not support direct cartesian feedback.
+ Returns None if not available.
+ """
+ if not self._sdk:
+ return None
+
+ try:
+ if hasattr(self._sdk, "GetArmEndPoseMsgs"):
+ pose_msgs = self._sdk.GetArmEndPoseMsgs()
+ if pose_msgs and pose_msgs.end_pose:
+ ep = pose_msgs.end_pose
+ return {
+                        "x": ep.X_axis / 1000.0,  # assumes mm -> m; confirm SDK unit (some builds use 0.001 mm)
+ "y": ep.Y_axis / 1000.0,
+ "z": ep.Z_axis / 1000.0,
+ "roll": ep.RX_axis * PIPER_TO_RAD,
+ "pitch": ep.RY_axis * PIPER_TO_RAD,
+ "yaw": ep.RZ_axis * PIPER_TO_RAD,
+ }
+ except Exception:
+ pass
+
+ return None
+
+ def write_cartesian_position(
+ self,
+ pose: dict[str, float],
+ velocity: float = 1.0,
+ ) -> bool:
+ """Write end-effector pose.
+
+ Note: Piper may not support direct cartesian control.
+ """
+ # Cartesian control not commonly supported in Piper SDK
+ return False
+
+ # =========================================================================
+ # Gripper (Optional)
+ # =========================================================================
+
+ def read_gripper_position(self) -> float | None:
+ """Read gripper position (percentage -> meters)."""
+ if not self._sdk:
+ return None
+
+ try:
+ if hasattr(self._sdk, "GetArmGripperMsgs"):
+ gripper_msgs = self._sdk.GetArmGripperMsgs()
+ if gripper_msgs and gripper_msgs.gripper_state:
+                    # NOTE(review): assumes grippers_angle is a 0-100 percentage;
+                    # SDK GripperCtrl docs use 0-1000 — confirm scale. Max opening 0.08 m.
+ pos = gripper_msgs.gripper_state.grippers_angle
+ return float(pos / 100.0) * 0.08
+ except Exception:
+ pass
+
+ return None
+
+ def write_gripper_position(self, position: float) -> bool:
+ """Write gripper position (meters -> percentage)."""
+ if not self._sdk:
+ return False
+
+ try:
+ if hasattr(self._sdk, "GripperCtrl"):
+                # Convert meters to a 0-100 value (assumes 0.08 m max opening).
+                # NOTE(review): GripperCtrl elsewhere documents 0-1000 — confirm scale.
+ percentage = int((position / 0.08) * 100)
+ percentage = max(0, min(100, percentage))
+ self._sdk.GripperCtrl(percentage, 1000, 0x01, 0)
+ return True
+ except Exception:
+ pass
+
+ return False
+
+ # =========================================================================
+ # Force/Torque Sensor (Optional)
+ # =========================================================================
+
+ def read_force_torque(self) -> list[float] | None:
+ """Read F/T sensor data.
+
+ Note: Piper doesn't typically have F/T sensor.
+ """
+ return None
+
+
+__all__ = ["PiperBackend"]
diff --git a/dimos/hardware/manipulators/piper/can_activate.sh b/dimos/hardware/manipulators/piper/can_activate.sh
deleted file mode 100644
index addb892557..0000000000
--- a/dimos/hardware/manipulators/piper/can_activate.sh
+++ /dev/null
@@ -1,138 +0,0 @@
-#!/bin/bash
-
-# The default CAN name can be set by the user via command-line parameters.
-DEFAULT_CAN_NAME="${1:-can0}"
-
-# The default bitrate for a single CAN module can be set by the user via command-line parameters.
-DEFAULT_BITRATE="${2:-1000000}"
-
-# USB hardware address (optional parameter)
-USB_ADDRESS="${3}"
-echo "-------------------START-----------------------"
-# Check if ethtool is installed.
-if ! dpkg -l | grep -q "ethtool"; then
- echo "\e[31mError: ethtool not detected in the system.\e[0m"
- echo "Please use the following command to install ethtool:"
- echo "sudo apt update && sudo apt install ethtool"
- exit 1
-fi
-
-# Check if can-utils is installed.
-if ! dpkg -l | grep -q "can-utils"; then
- echo "\e[31mError: can-utils not detected in the system.\e[0m"
- echo "Please use the following command to install ethtool:"
- echo "sudo apt update && sudo apt install can-utils"
- exit 1
-fi
-
-echo "Both ethtool and can-utils are installed."
-
-# Retrieve the number of CAN modules in the current system.
-CURRENT_CAN_COUNT=$(ip link show type can | grep -c "link/can")
-
-# Verify if the number of CAN modules in the current system matches the expected value.
-if [ "$CURRENT_CAN_COUNT" -ne "1" ]; then
- if [ -z "$USB_ADDRESS" ]; then
- # Iterate through all CAN interfaces.
- for iface in $(ip -br link show type can | awk '{print $1}'); do
- # Use ethtool to retrieve bus-info.
- BUS_INFO=$(sudo ethtool -i "$iface" | grep "bus-info" | awk '{print $2}')
-
- if [ -z "$BUS_INFO" ];then
- echo "Error: Unable to retrieve bus-info for interface $iface."
- continue
- fi
-
- echo "Interface $iface is inserted into USB port $BUS_INFO"
- done
- echo -e " \e[31m Error: The number of CAN modules detected by the system ($CURRENT_CAN_COUNT) does not match the expected number (1). \e[0m"
- echo -e " \e[31m Please add the USB hardware address parameter, such as: \e[0m"
- echo -e " bash can_activate.sh can0 1000000 1-2:1.0"
- echo "-------------------ERROR-----------------------"
- exit 1
- fi
-fi
-
-# Load the gs_usb module.
-# sudo modprobe gs_usb
-# if [ $? -ne 0 ]; then
-# echo "Error: Unable to load the gs_usb module."
-# exit 1
-# fi
-
-if [ -n "$USB_ADDRESS" ]; then
- echo "Detected USB hardware address parameter: $USB_ADDRESS"
-
- # Use ethtool to find the CAN interface corresponding to the USB hardware address.
- INTERFACE_NAME=""
- for iface in $(ip -br link show type can | awk '{print $1}'); do
- BUS_INFO=$(sudo ethtool -i "$iface" | grep "bus-info" | awk '{print $2}')
- if [ "$BUS_INFO" = "$USB_ADDRESS" ]; then
- INTERFACE_NAME="$iface"
- break
- fi
- done
-
- if [ -z "$INTERFACE_NAME" ]; then
- echo "Error: Unable to find CAN interface corresponding to USB hardware address $USB_ADDRESS."
- exit 1
- else
- echo "Found the interface corresponding to USB hardware address $USB_ADDRESS: $INTERFACE_NAME."
- fi
-else
- # Retrieve the unique CAN interface.
- INTERFACE_NAME=$(ip -br link show type can | awk '{print $1}')
-
- # Check if the interface name has been retrieved.
- if [ -z "$INTERFACE_NAME" ]; then
- echo "Error: Unable to detect CAN interface."
- exit 1
- fi
- BUS_INFO=$(sudo ethtool -i "$INTERFACE_NAME" | grep "bus-info" | awk '{print $2}')
- echo "Expected to configure a single CAN module, detected interface $INTERFACE_NAME with corresponding USB address $BUS_INFO."
-fi
-
-# Check if the current interface is already activated.
-IS_LINK_UP=$(ip link show "$INTERFACE_NAME" | grep -q "UP" && echo "yes" || echo "no")
-
-# Retrieve the bitrate of the current interface.
-CURRENT_BITRATE=$(ip -details link show "$INTERFACE_NAME" | grep -oP 'bitrate \K\d+')
-
-if [ "$IS_LINK_UP" = "yes" ] && [ "$CURRENT_BITRATE" -eq "$DEFAULT_BITRATE" ]; then
- echo "Interface $INTERFACE_NAME is already activated with a bitrate of $DEFAULT_BITRATE."
-
- # Check if the interface name matches the default name.
- if [ "$INTERFACE_NAME" != "$DEFAULT_CAN_NAME" ]; then
- echo "Rename interface $INTERFACE_NAME to $DEFAULT_CAN_NAME."
- sudo ip link set "$INTERFACE_NAME" down
- sudo ip link set "$INTERFACE_NAME" name "$DEFAULT_CAN_NAME"
- sudo ip link set "$DEFAULT_CAN_NAME" up
- echo "The interface has been renamed to $DEFAULT_CAN_NAME and reactivated."
- else
- echo "The interface name is already $DEFAULT_CAN_NAME."
- fi
-else
- # If the interface is not activated or the bitrate is different, configure it.
- if [ "$IS_LINK_UP" = "yes" ]; then
- echo "Interface $INTERFACE_NAME is already activated, but the bitrate is $CURRENT_BITRATE, which does not match the set value of $DEFAULT_BITRATE."
- else
- echo "Interface $INTERFACE_NAME is not activated or bitrate is not set."
- fi
-
- # Set the interface bitrate and activate it.
- sudo ip link set "$INTERFACE_NAME" down
- sudo ip link set "$INTERFACE_NAME" type can bitrate $DEFAULT_BITRATE
- sudo ip link set "$INTERFACE_NAME" up
- echo "Interface $INTERFACE_NAME has been reset to bitrate $DEFAULT_BITRATE and activated."
-
- # Rename the interface to the default name.
- if [ "$INTERFACE_NAME" != "$DEFAULT_CAN_NAME" ]; then
- echo "Rename interface $INTERFACE_NAME to $DEFAULT_CAN_NAME."
- sudo ip link set "$INTERFACE_NAME" down
- sudo ip link set "$INTERFACE_NAME" name "$DEFAULT_CAN_NAME"
- sudo ip link set "$DEFAULT_CAN_NAME" up
- echo "The interface has been renamed to $DEFAULT_CAN_NAME and reactivated."
- fi
-fi
-
-echo "-------------------OVER------------------------"
diff --git a/dimos/hardware/manipulators/piper/components/__init__.py b/dimos/hardware/manipulators/piper/components/__init__.py
deleted file mode 100644
index 2c6d863ca1..0000000000
--- a/dimos/hardware/manipulators/piper/components/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Component classes for PiperDriver."""
-
-from .configuration import ConfigurationComponent
-from .gripper_control import GripperControlComponent
-from .kinematics import KinematicsComponent
-from .motion_control import MotionControlComponent
-from .state_queries import StateQueryComponent
-from .system_control import SystemControlComponent
-
-__all__ = [
- "ConfigurationComponent",
- "GripperControlComponent",
- "KinematicsComponent",
- "MotionControlComponent",
- "StateQueryComponent",
- "SystemControlComponent",
-]
diff --git a/dimos/hardware/manipulators/piper/components/configuration.py b/dimos/hardware/manipulators/piper/components/configuration.py
deleted file mode 100644
index b7ac53c371..0000000000
--- a/dimos/hardware/manipulators/piper/components/configuration.py
+++ /dev/null
@@ -1,348 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Configuration Component for PiperDriver.
-
-Provides RPC methods for configuring robot parameters including:
-- Joint parameters (limits, speeds, acceleration)
-- End-effector parameters (speed, acceleration)
-- Collision protection
-- Motor configuration
-"""
-
-from typing import Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class ConfigurationComponent:
- """
- Component providing configuration RPC methods for PiperDriver.
-
- This component assumes the parent class has:
- - self.piper: C_PiperInterface_V2 instance
- - self.config: PiperDriverConfig instance
- """
-
- # Type hints for attributes provided by parent class
- piper: Any
- config: Any
-
- @rpc
- def set_joint_config(
- self,
- motor_num: int,
- kp_factor: int,
- ki_factor: int,
- kd_factor: int,
- ke_factor: int = 0,
- ) -> tuple[bool, str]:
- """
- Configure joint control parameters.
-
- Args:
- motor_num: Motor number (1-6)
- kp_factor: Proportional gain factor
- ki_factor: Integral gain factor
- kd_factor: Derivative gain factor
- ke_factor: Error gain factor
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if motor_num not in range(1, 7):
- return (False, f"Invalid motor_num: {motor_num}. Must be 1-6")
-
- result = self.piper.JointConfig(motor_num, kp_factor, ki_factor, kd_factor, ke_factor)
-
- if result:
- return (True, f"Joint {motor_num} configuration set successfully")
- else:
- return (False, f"Failed to configure joint {motor_num}")
-
- except Exception as e:
- logger.error(f"set_joint_config failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_joint_max_acc(self, motor_num: int, max_joint_acc: int) -> tuple[bool, str]:
- """
- Set joint maximum acceleration.
-
- Args:
- motor_num: Motor number (1-6)
- max_joint_acc: Maximum joint acceleration
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if motor_num not in range(1, 7):
- return (False, f"Invalid motor_num: {motor_num}. Must be 1-6")
-
- result = self.piper.JointMaxAccConfig(motor_num, max_joint_acc)
-
- if result:
- return (True, f"Joint {motor_num} max acceleration set to {max_joint_acc}")
- else:
- return (False, f"Failed to set max acceleration for joint {motor_num}")
-
- except Exception as e:
- logger.error(f"set_joint_max_acc failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_motor_angle_limit_max_speed(
- self,
- motor_num: int,
- min_joint_angle: int,
- max_joint_angle: int,
- max_joint_speed: int,
- ) -> tuple[bool, str]:
- """
- Set motor angle limits and maximum speed.
-
- Args:
- motor_num: Motor number (1-6)
- min_joint_angle: Minimum joint angle (in Piper units: 0.001 degrees)
- max_joint_angle: Maximum joint angle (in Piper units: 0.001 degrees)
- max_joint_speed: Maximum joint speed
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if motor_num not in range(1, 7):
- return (False, f"Invalid motor_num: {motor_num}. Must be 1-6")
-
- result = self.piper.MotorAngleLimitMaxSpdSet(
- motor_num, min_joint_angle, max_joint_angle, max_joint_speed
- )
-
- if result:
- return (
- True,
- f"Joint {motor_num} angle limits and max speed set successfully",
- )
- else:
- return (False, f"Failed to set angle limits for joint {motor_num}")
-
- except Exception as e:
- logger.error(f"set_motor_angle_limit_max_speed failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_motor_max_speed(self, motor_num: int, max_joint_spd: int) -> tuple[bool, str]:
- """
- Set motor maximum speed.
-
- Args:
- motor_num: Motor number (1-6)
- max_joint_spd: Maximum joint speed
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if motor_num not in range(1, 7):
- return (False, f"Invalid motor_num: {motor_num}. Must be 1-6")
-
- result = self.piper.MotorMaxSpdSet(motor_num, max_joint_spd)
-
- if result:
- return (True, f"Joint {motor_num} max speed set to {max_joint_spd}")
- else:
- return (False, f"Failed to set max speed for joint {motor_num}")
-
- except Exception as e:
- logger.error(f"set_motor_max_speed failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_end_speed_and_acc(
- self,
- end_max_linear_vel: int,
- end_max_angular_vel: int,
- end_max_linear_acc: int,
- end_max_angular_acc: int,
- ) -> tuple[bool, str]:
- """
- Set end-effector speed and acceleration parameters.
-
- Args:
- end_max_linear_vel: Maximum linear velocity
- end_max_angular_vel: Maximum angular velocity
- end_max_linear_acc: Maximum linear acceleration
- end_max_angular_acc: Maximum angular acceleration
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.EndSpdAndAccParamSet(
- end_max_linear_vel,
- end_max_angular_vel,
- end_max_linear_acc,
- end_max_angular_acc,
- )
-
- if result:
- return (True, "End-effector speed and acceleration parameters set successfully")
- else:
- return (False, "Failed to set end-effector parameters")
-
- except Exception as e:
- logger.error(f"set_end_speed_and_acc failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_crash_protection_level(self, level: int) -> tuple[bool, str]:
- """
- Set collision/crash protection level.
-
- Args:
- level: Protection level (0=disabled, higher values = more sensitive)
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.CrashProtectionConfig(level)
-
- if result:
- return (True, f"Crash protection level set to {level}")
- else:
- return (False, "Failed to set crash protection level")
-
- except Exception as e:
- logger.error(f"set_crash_protection_level failed: {e}")
- return (False, str(e))
-
- @rpc
- def search_motor_max_angle_speed_acc_limit(self, motor_num: int) -> tuple[bool, str]:
- """
- Search for motor maximum angle, speed, and acceleration limits.
-
- Args:
- motor_num: Motor number (1-6)
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if motor_num not in range(1, 7):
- return (False, f"Invalid motor_num: {motor_num}. Must be 1-6")
-
- result = self.piper.SearchMotorMaxAngleSpdAccLimit(motor_num)
-
- if result:
- return (True, f"Search initiated for motor {motor_num} limits")
- else:
- return (False, f"Failed to search limits for motor {motor_num}")
-
- except Exception as e:
- logger.error(f"search_motor_max_angle_speed_acc_limit failed: {e}")
- return (False, str(e))
-
- @rpc
- def search_all_motor_max_angle_speed(self) -> tuple[bool, str]:
- """
- Search for all motors' maximum angle and speed limits.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.SearchAllMotorMaxAngleSpd()
-
- if result:
- return (True, "Search initiated for all motor angle/speed limits")
- else:
- return (False, "Failed to search all motor limits")
-
- except Exception as e:
- logger.error(f"search_all_motor_max_angle_speed failed: {e}")
- return (False, str(e))
-
- @rpc
- def search_all_motor_max_acc_limit(self) -> tuple[bool, str]:
- """
- Search for all motors' maximum acceleration limits.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.SearchAllMotorMaxAccLimit()
-
- if result:
- return (True, "Search initiated for all motor acceleration limits")
- else:
- return (False, "Failed to search all motor acceleration limits")
-
- except Exception as e:
- logger.error(f"search_all_motor_max_acc_limit failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_sdk_joint_limit_param(
- self, joint_limits: list[tuple[float, float]]
- ) -> tuple[bool, str]:
- """
- Set SDK joint limit parameters.
-
- Args:
- joint_limits: List of (min_angle, max_angle) tuples for each joint in radians
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if len(joint_limits) != 6:
- return (False, f"Expected 6 joint limit tuples, got {len(joint_limits)}")
-
- # Convert to Piper units and call SDK method
- # Note: Actual SDK method signature may vary
- logger.info(f"Setting SDK joint limits: {joint_limits}")
- return (True, "SDK joint limits set (method may vary by SDK version)")
-
- except Exception as e:
- logger.error(f"set_sdk_joint_limit_param failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_sdk_gripper_range_param(self, min_range: int, max_range: int) -> tuple[bool, str]:
- """
- Set SDK gripper range parameters.
-
- Args:
- min_range: Minimum gripper range
- max_range: Maximum gripper range
-
- Returns:
- Tuple of (success, message)
- """
- try:
- # Note: Actual SDK method signature may vary
- logger.info(f"Setting SDK gripper range: {min_range} - {max_range}")
- return (True, "SDK gripper range set (method may vary by SDK version)")
-
- except Exception as e:
- logger.error(f"set_sdk_gripper_range_param failed: {e}")
- return (False, str(e))
diff --git a/dimos/hardware/manipulators/piper/components/gripper_control.py b/dimos/hardware/manipulators/piper/components/gripper_control.py
deleted file mode 100644
index 5f500097cd..0000000000
--- a/dimos/hardware/manipulators/piper/components/gripper_control.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Gripper Control Component for PiperDriver.
-
-Provides RPC methods for gripper control operations.
-"""
-
-from typing import Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class GripperControlComponent:
- """
- Component providing gripper control RPC methods for PiperDriver.
-
- This component assumes the parent class has:
- - self.piper: C_PiperInterface_V2 instance
- - self.config: PiperDriverConfig instance
- """
-
- # Type hints for attributes provided by parent class
- piper: Any
- config: Any
-
- @rpc
- def set_gripper(
- self,
- gripper_angle: int,
- gripper_effort: int = 100,
- gripper_enable: int = 0x01,
- gripper_state: int = 0x00,
- ) -> tuple[bool, str]:
- """
- Set gripper position and parameters.
-
- Args:
- gripper_angle: Gripper angle (0-1000, 0=closed, 1000=open)
- gripper_effort: Gripper effort/force (0-1000)
- gripper_enable: Gripper enable (0x00=disabled, 0x01=enabled)
- gripper_state: Gripper state
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.GripperCtrl(
- gripper_angle, gripper_effort, gripper_enable, gripper_state
- )
-
- if result:
- return (True, f"Gripper set to angle={gripper_angle}, effort={gripper_effort}")
- else:
- return (False, "Failed to set gripper")
-
- except Exception as e:
- logger.error(f"set_gripper failed: {e}")
- return (False, str(e))
-
- @rpc
- def open_gripper(self, effort: int = 100) -> tuple[bool, str]:
- """
- Open gripper.
-
- Args:
- effort: Gripper effort (0-1000)
-
- Returns:
- Tuple of (success, message)
- """
- result: tuple[bool, str] = self.set_gripper(gripper_angle=1000, gripper_effort=effort)
- return result
-
- @rpc
- def close_gripper(self, effort: int = 100) -> tuple[bool, str]:
- """
- Close gripper.
-
- Args:
- effort: Gripper effort (0-1000)
-
- Returns:
- Tuple of (success, message)
- """
- result: tuple[bool, str] = self.set_gripper(gripper_angle=0, gripper_effort=effort)
- return result
-
- @rpc
- def set_gripper_zero(self) -> tuple[bool, str]:
- """
- Set gripper zero position.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- # This method may require specific SDK implementation
- # For now, we'll just document it
- logger.info("set_gripper_zero called - implementation may vary by SDK version")
- return (True, "Gripper zero set (if supported by SDK)")
-
- except Exception as e:
- logger.error(f"set_gripper_zero failed: {e}")
- return (False, str(e))
diff --git a/dimos/hardware/manipulators/piper/components/kinematics.py b/dimos/hardware/manipulators/piper/components/kinematics.py
deleted file mode 100644
index 51be97a764..0000000000
--- a/dimos/hardware/manipulators/piper/components/kinematics.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Kinematics Component for PiperDriver.
-
-Provides RPC methods for kinematic calculations including:
-- Forward kinematics
-"""
-
-from typing import Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class KinematicsComponent:
- """
- Component providing kinematics RPC methods for PiperDriver.
-
- This component assumes the parent class has:
- - self.piper: C_PiperInterface_V2 instance
- - self.config: PiperDriverConfig instance
- - PIPER_TO_RAD: conversion constant (0.001 degrees → radians)
- """
-
- # Type hints for attributes provided by parent class
- piper: Any
- config: Any
-
- @rpc
- def get_forward_kinematics(
- self, mode: str = "feedback"
- ) -> tuple[bool, dict[str, float] | None]:
- """
- Compute forward kinematics.
-
- Args:
- mode: "feedback" for current joint angles, "control" for commanded angles
-
- Returns:
- Tuple of (success, pose_dict) with keys: x, y, z, rx, ry, rz
- """
- try:
- fk_result = self.piper.GetFK(mode=mode)
-
- if fk_result is not None:
- # Convert from Piper units
- pose_dict = {
- "x": fk_result[0] * 0.001, # 0.001 mm → mm
- "y": fk_result[1] * 0.001,
- "z": fk_result[2] * 0.001,
- "rx": fk_result[3] * 0.001 * (3.14159 / 180.0), # → rad
- "ry": fk_result[4] * 0.001 * (3.14159 / 180.0),
- "rz": fk_result[5] * 0.001 * (3.14159 / 180.0),
- }
- return (True, pose_dict)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_forward_kinematics failed: {e}")
- return (False, None)
-
- @rpc
- def enable_fk_calculation(self) -> tuple[bool, str]:
- """
- Enable forward kinematics calculation.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.EnableFkCal()
-
- if result:
- return (True, "FK calculation enabled")
- else:
- return (False, "Failed to enable FK calculation")
-
- except Exception as e:
- logger.error(f"enable_fk_calculation failed: {e}")
- return (False, str(e))
-
- @rpc
- def disable_fk_calculation(self) -> tuple[bool, str]:
- """
- Disable forward kinematics calculation.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.DisableFkCal()
-
- if result:
- return (True, "FK calculation disabled")
- else:
- return (False, "Failed to disable FK calculation")
-
- except Exception as e:
- logger.error(f"disable_fk_calculation failed: {e}")
- return (False, str(e))
diff --git a/dimos/hardware/manipulators/piper/components/motion_control.py b/dimos/hardware/manipulators/piper/components/motion_control.py
deleted file mode 100644
index 7a0dc36eed..0000000000
--- a/dimos/hardware/manipulators/piper/components/motion_control.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Motion Control Component for PiperDriver.
-
-Provides RPC methods for motion control operations including:
-- Joint position control
-- Joint velocity control
-- End-effector pose control
-- Emergency stop
-- Circular motion
-"""
-
-import math
-import time
-from typing import Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class MotionControlComponent:
- """
- Component providing motion control RPC methods for PiperDriver.
-
- This component assumes the parent class has:
- - self.piper: C_PiperInterface_V2 instance
- - self.config: PiperDriverConfig instance
- - RAD_TO_PIPER: conversion constant (radians → 0.001 degrees)
- - PIPER_TO_RAD: conversion constant (0.001 degrees → radians)
- """
-
- # Type hints for attributes expected from parent class
- piper: Any
- config: Any
- RAD_TO_PIPER: float
- PIPER_TO_RAD: float
- _joint_cmd_lock: Any
- _joint_cmd_: Any
- _vel_cmd_: Any
- _last_cmd_time: float
-
- @rpc
- def set_joint_angles(self, angles: list[float], gripper_state: int = 0x00) -> tuple[bool, str]:
- """
- Set joint angles (RPC method).
-
- Args:
- angles: List of joint angles in radians
- gripper_state: Gripper state (0x00 = no change, 0x01 = open, 0x02 = close)
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if len(angles) != 6:
- return (False, f"Expected 6 joint angles, got {len(angles)}")
-
- # Convert radians to Piper units (0.001 degrees)
- piper_joints = [round(rad * self.RAD_TO_PIPER) for rad in angles]
-
- # Send joint control command
- result = self.piper.JointCtrl(
- piper_joints[0],
- piper_joints[1],
- piper_joints[2],
- piper_joints[3],
- piper_joints[4],
- piper_joints[5],
- gripper_state,
- )
-
- if result:
- return (True, "Joint angles set successfully")
- else:
- return (False, "Failed to set joint angles")
-
- except Exception as e:
- logger.error(f"set_joint_angles failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_joint_command(self, positions: list[float]) -> tuple[bool, str]:
- """
- Manually set the joint command (for testing).
- This updates the shared joint_cmd that the control loop reads.
-
- Args:
- positions: List of joint positions in radians
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if len(positions) != 6:
- return (False, f"Expected 6 joint positions, got {len(positions)}")
-
- with self._joint_cmd_lock:
- self._joint_cmd_ = list(positions)
-
- logger.info(f"✓ Joint command set: {[f'{math.degrees(p):.2f}°' for p in positions]}")
- return (True, "Joint command updated")
- except Exception as e:
- return (False, str(e))
-
- @rpc
- def set_end_pose(
- self, x: float, y: float, z: float, rx: float, ry: float, rz: float
- ) -> tuple[bool, str]:
- """
- Set end-effector pose.
-
- Args:
- x: X position in millimeters
- y: Y position in millimeters
- z: Z position in millimeters
- rx: Roll in radians
- ry: Pitch in radians
- rz: Yaw in radians
-
- Returns:
- Tuple of (success, message)
- """
- try:
- # Convert to Piper units
- # Position: mm → 0.001 mm
- x_piper = round(x * 1000)
- y_piper = round(y * 1000)
- z_piper = round(z * 1000)
-
- # Rotation: radians → 0.001 degrees
- rx_piper = round(math.degrees(rx) * 1000)
- ry_piper = round(math.degrees(ry) * 1000)
- rz_piper = round(math.degrees(rz) * 1000)
-
- # Send end pose control command
- result = self.piper.EndPoseCtrl(x_piper, y_piper, z_piper, rx_piper, ry_piper, rz_piper)
-
- if result:
- return (True, "End pose set successfully")
- else:
- return (False, "Failed to set end pose")
-
- except Exception as e:
- logger.error(f"set_end_pose failed: {e}")
- return (False, str(e))
-
- @rpc
- def emergency_stop(self) -> tuple[bool, str]:
- """Emergency stop the arm."""
- try:
- result = self.piper.EmergencyStop()
-
- if result:
- logger.warning("Emergency stop activated")
- return (True, "Emergency stop activated")
- else:
- return (False, "Failed to activate emergency stop")
-
- except Exception as e:
- logger.error(f"emergency_stop failed: {e}")
- return (False, str(e))
-
- @rpc
- def move_c_axis_update(self, instruction_num: int = 0x00) -> tuple[bool, str]:
- """
- Update circular motion axis.
-
- Args:
- instruction_num: Instruction number (0x00, 0x01, 0x02, 0x03)
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if instruction_num not in [0x00, 0x01, 0x02, 0x03]:
- return (False, f"Invalid instruction_num: {instruction_num}")
-
- result = self.piper.MoveCAxisUpdateCtrl(instruction_num)
-
- if result:
- return (True, f"Move C axis updated with instruction {instruction_num}")
- else:
- return (False, "Failed to update Move C axis")
-
- except Exception as e:
- logger.error(f"move_c_axis_update failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_joint_mit_ctrl(
- self,
- motor_num: int,
- pos_target: float,
- vel_target: float,
- torq_target: float,
- kp: int,
- kd: int,
- ) -> tuple[bool, str]:
- """
- Set joint MIT (Model-based Inverse Torque) control.
-
- Args:
- motor_num: Motor number (1-6)
- pos_target: Target position in radians
- vel_target: Target velocity in rad/s
- torq_target: Target torque in Nm
- kp: Proportional gain (0-100)
- kd: Derivative gain (0-100)
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if motor_num not in range(1, 7):
- return (False, f"Invalid motor_num: {motor_num}. Must be 1-6")
-
- # Convert to Piper units
- pos_piper = round(pos_target * self.RAD_TO_PIPER)
- vel_piper = round(vel_target * self.RAD_TO_PIPER)
- torq_piper = round(torq_target * 1000) # Torque in millinewton-meters
-
- result = self.piper.JointMitCtrl(motor_num, pos_piper, vel_piper, torq_piper, kp, kd)
-
- if result:
- return (True, f"Joint {motor_num} MIT control set successfully")
- else:
- return (False, f"Failed to set MIT control for joint {motor_num}")
-
- except Exception as e:
- logger.error(f"set_joint_mit_ctrl failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_joint_velocities(self, velocities: list[float]) -> tuple[bool, str]:
- """
- Set joint velocities (RPC method).
-
- Requires velocity control mode to be enabled.
-
- The control loop integrates velocities to positions:
- - position_target += velocity * dt
- - Integrated positions are sent to JointCtrl
-
- This provides smooth velocity control while using the proven position API.
-
- Args:
- velocities: List of 6 joint velocities in rad/s
-
- Returns:
- Tuple of (success, message)
- """
- try:
- if len(velocities) != 6:
- return (False, f"Expected 6 velocities, got {len(velocities)}")
-
- if not self.config.velocity_control:
- return (
- False,
- "Velocity control mode not enabled. Call enable_velocity_control_mode() first.",
- )
-
- with self._joint_cmd_lock:
- self._vel_cmd_ = list(velocities)
- self._last_cmd_time = time.time()
-
- logger.info(f"✓ Velocity command set: {[f'{v:.3f} rad/s' for v in velocities]}")
- return (True, "Velocity command updated")
-
- except Exception as e:
- logger.error(f"set_joint_velocities failed: {e}")
- return (False, str(e))
diff --git a/dimos/hardware/manipulators/piper/components/state_queries.py b/dimos/hardware/manipulators/piper/components/state_queries.py
deleted file mode 100644
index 3fe00fffc6..0000000000
--- a/dimos/hardware/manipulators/piper/components/state_queries.py
+++ /dev/null
@@ -1,340 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-State Query Component for PiperDriver.
-
-Provides RPC methods for querying robot state including:
-- Joint state
-- Robot state
-- End-effector pose
-- Gripper state
-- Motor information
-- Firmware version
-"""
-
-import threading
-from typing import Any
-
-from dimos.core import rpc
-from dimos.msgs.sensor_msgs import JointState, RobotState
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class StateQueryComponent:
- """
- Component providing state query RPC methods for PiperDriver.
-
- This component assumes the parent class has:
- - self.piper: C_PiperInterface_V2 instance
- - self.config: PiperDriverConfig instance
- - self._joint_state_lock: threading.Lock
- - self._joint_states_: Optional[JointState]
- - self._robot_state_: Optional[RobotState]
- - PIPER_TO_RAD: conversion constant (0.001 degrees → radians)
- """
-
- # Type hints for attributes expected from parent class
- piper: Any # C_PiperInterface_V2 instance
- config: Any # Config dict accessed as object
- _joint_state_lock: threading.Lock
- _joint_states_: JointState | None
- _robot_state_: RobotState | None
- PIPER_TO_RAD: float
-
- @rpc
- def get_joint_state(self) -> JointState | None:
- """
- Get the current joint state (RPC method).
-
- Returns:
- Current JointState or None
- """
- with self._joint_state_lock:
- return self._joint_states_
-
- @rpc
- def get_robot_state(self) -> RobotState | None:
- """
- Get the current robot state (RPC method).
-
- Returns:
- Current RobotState or None
- """
- with self._joint_state_lock:
- return self._robot_state_
-
- @rpc
- def get_arm_status(self) -> tuple[bool, dict[str, Any] | None]:
- """
- Get arm status.
-
- Returns:
- Tuple of (success, status_dict)
- """
- try:
- status = self.piper.GetArmStatus()
-
- if status is not None:
- status_dict = {
- "time_stamp": status.time_stamp,
- "Hz": status.Hz,
- "motion_mode": status.arm_status.motion_mode,
- "mode_feedback": status.arm_status.mode_feedback,
- "teach_status": status.arm_status.teach_status,
- "motion_status": status.arm_status.motion_status,
- "trajectory_num": status.arm_status.trajectory_num,
- }
- return (True, status_dict)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_arm_status failed: {e}")
- return (False, None)
-
- @rpc
- def get_arm_joint_angles(self) -> tuple[bool, list[float] | None]:
- """
- Get arm joint angles in radians.
-
- Returns:
- Tuple of (success, joint_angles)
- """
- try:
- arm_joint = self.piper.GetArmJointMsgs()
-
- if arm_joint is not None:
- # Convert from Piper units (0.001 degrees) to radians
- angles = [
- arm_joint.joint_state.joint_1 * self.PIPER_TO_RAD,
- arm_joint.joint_state.joint_2 * self.PIPER_TO_RAD,
- arm_joint.joint_state.joint_3 * self.PIPER_TO_RAD,
- arm_joint.joint_state.joint_4 * self.PIPER_TO_RAD,
- arm_joint.joint_state.joint_5 * self.PIPER_TO_RAD,
- arm_joint.joint_state.joint_6 * self.PIPER_TO_RAD,
- ]
- return (True, angles)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_arm_joint_angles failed: {e}")
- return (False, None)
-
- @rpc
- def get_end_pose(self) -> tuple[bool, dict[str, float] | None]:
- """
- Get end-effector pose.
-
- Returns:
- Tuple of (success, pose_dict) with keys: x, y, z, rx, ry, rz
- """
- try:
- end_pose = self.piper.GetArmEndPoseMsgs()
-
- if end_pose is not None:
- # Convert from Piper units
- pose_dict = {
- "x": end_pose.end_pose.end_pose_x * 0.001, # 0.001 mm → mm
- "y": end_pose.end_pose.end_pose_y * 0.001,
- "z": end_pose.end_pose.end_pose_z * 0.001,
- "rx": end_pose.end_pose.end_pose_rx * 0.001 * (3.14159 / 180.0), # → rad
- "ry": end_pose.end_pose.end_pose_ry * 0.001 * (3.14159 / 180.0),
- "rz": end_pose.end_pose.end_pose_rz * 0.001 * (3.14159 / 180.0),
- "time_stamp": end_pose.time_stamp,
- "Hz": end_pose.Hz,
- }
- return (True, pose_dict)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_end_pose failed: {e}")
- return (False, None)
-
- @rpc
- def get_gripper_state(self) -> tuple[bool, dict[str, Any] | None]:
- """
- Get gripper state.
-
- Returns:
- Tuple of (success, gripper_dict)
- """
- try:
- gripper = self.piper.GetArmGripperMsgs()
-
- if gripper is not None:
- gripper_dict = {
- "gripper_angle": gripper.gripper_state.grippers_angle,
- "gripper_effort": gripper.gripper_state.grippers_effort,
- "gripper_enable": gripper.gripper_state.grippers_enabled,
- "time_stamp": gripper.time_stamp,
- "Hz": gripper.Hz,
- }
- return (True, gripper_dict)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_gripper_state failed: {e}")
- return (False, None)
-
- @rpc
- def get_arm_enable_status(self) -> tuple[bool, list[int] | None]:
- """
- Get arm enable status for all joints.
-
- Returns:
- Tuple of (success, enable_status_list)
- """
- try:
- enable_status = self.piper.GetArmEnableStatus()
-
- if enable_status is not None:
- return (True, enable_status)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_arm_enable_status failed: {e}")
- return (False, None)
-
- @rpc
- def get_firmware_version(self) -> tuple[bool, str | None]:
- """
- Get Piper firmware version.
-
- Returns:
- Tuple of (success, version_string)
- """
- try:
- version = self.piper.GetPiperFirmwareVersion()
-
- if version is not None:
- return (True, version)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_firmware_version failed: {e}")
- return (False, None)
-
- @rpc
- def get_sdk_version(self) -> tuple[bool, str | None]:
- """
- Get Piper SDK version.
-
- Returns:
- Tuple of (success, version_string)
- """
- try:
- version = self.piper.GetCurrentSDKVersion()
-
- if version is not None:
- return (True, version)
- else:
- return (False, None)
-
- except Exception:
- return (False, None)
-
- @rpc
- def get_interface_version(self) -> tuple[bool, str | None]:
- """
- Get Piper interface version.
-
- Returns:
- Tuple of (success, version_string)
- """
- try:
- version = self.piper.GetCurrentInterfaceVersion()
-
- if version is not None:
- return (True, version)
- else:
- return (False, None)
-
- except Exception:
- return (False, None)
-
- @rpc
- def get_protocol_version(self) -> tuple[bool, str | None]:
- """
- Get Piper protocol version.
-
- Returns:
- Tuple of (success, version_string)
- """
- try:
- version = self.piper.GetCurrentProtocolVersion()
-
- if version is not None:
- return (True, version)
- else:
- return (False, None)
-
- except Exception:
- return (False, None)
-
- @rpc
- def get_can_fps(self) -> tuple[bool, float | None]:
- """
- Get CAN bus FPS (frames per second).
-
- Returns:
- Tuple of (success, fps_value)
- """
- try:
- fps = self.piper.GetCanFps()
-
- if fps is not None:
- return (True, fps)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_can_fps failed: {e}")
- return (False, None)
-
- @rpc
- def get_motor_max_acc_limit(self) -> tuple[bool, dict[str, Any] | None]:
- """
- Get maximum acceleration limit for all motors.
-
- Returns:
- Tuple of (success, acc_limit_dict)
- """
- try:
- acc_limit = self.piper.GetCurrentMotorMaxAccLimit()
-
- if acc_limit is not None:
- acc_dict = {
- "motor_1": acc_limit.current_motor_max_acc_limit.motor_1_max_acc_limit,
- "motor_2": acc_limit.current_motor_max_acc_limit.motor_2_max_acc_limit,
- "motor_3": acc_limit.current_motor_max_acc_limit.motor_3_max_acc_limit,
- "motor_4": acc_limit.current_motor_max_acc_limit.motor_4_max_acc_limit,
- "motor_5": acc_limit.current_motor_max_acc_limit.motor_5_max_acc_limit,
- "motor_6": acc_limit.current_motor_max_acc_limit.motor_6_max_acc_limit,
- "time_stamp": acc_limit.time_stamp,
- }
- return (True, acc_dict)
- else:
- return (False, None)
-
- except Exception as e:
- logger.error(f"get_motor_max_acc_limit failed: {e}")
- return (False, None)
diff --git a/dimos/hardware/manipulators/piper/components/system_control.py b/dimos/hardware/manipulators/piper/components/system_control.py
deleted file mode 100644
index a15eb29133..0000000000
--- a/dimos/hardware/manipulators/piper/components/system_control.py
+++ /dev/null
@@ -1,395 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-System Control Component for PiperDriver.
-
-Provides RPC methods for system-level control operations including:
-- Enable/disable arm
-- Mode control (drag teach, MIT control, etc.)
-- Motion control
-- Master/slave configuration
-"""
-
-from typing import Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class SystemControlComponent:
- """
- Component providing system control RPC methods for PiperDriver.
-
- This component assumes the parent class has:
- - self.piper: C_PiperInterface_V2 instance
- - self.config: PiperDriverConfig instance
- """
-
- # Type hints for attributes expected from parent class
- piper: Any # C_PiperInterface_V2 instance
- config: Any # Config dict accessed as object
-
- @rpc
- def enable_servo_mode(self) -> tuple[bool, str]:
- """
- Enable servo mode.
- This enables the arm to receive motion commands.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.EnableArm()
-
- if result:
- logger.info("Servo mode enabled")
- return (True, "Servo mode enabled")
- else:
- logger.warning("Failed to enable servo mode")
- return (False, "Failed to enable servo mode")
-
- except Exception as e:
- logger.error(f"enable_servo_mode failed: {e}")
- return (False, str(e))
-
- @rpc
- def disable_servo_mode(self) -> tuple[bool, str]:
- """
- Disable servo mode.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.DisableArm()
-
- if result:
- logger.info("Servo mode disabled")
- return (True, "Servo mode disabled")
- else:
- logger.warning("Failed to disable servo mode")
- return (False, "Failed to disable servo mode")
-
- except Exception as e:
- logger.error(f"disable_servo_mode failed: {e}")
- return (False, str(e))
-
- @rpc
- def motion_enable(self, enable: bool = True) -> tuple[bool, str]:
- """Enable or disable arm motion."""
- try:
- if enable:
- result = self.piper.EnableArm()
- msg = "Motion enabled"
- else:
- result = self.piper.DisableArm()
- msg = "Motion disabled"
-
- if result:
- return (True, msg)
- else:
- return (False, f"Failed to {msg.lower()}")
-
- except Exception as e:
- return (False, str(e))
-
- @rpc
- def set_motion_ctrl_1(
- self,
- ctrl_mode: int = 0x00,
- move_mode: int = 0x00,
- move_spd_rate: int = 50,
- coor_mode: int = 0x00,
- reference_joint: int = 0x00,
- ) -> tuple[bool, str]:
- """
- Set motion control parameters (MotionCtrl_1).
-
- Args:
- ctrl_mode: Control mode
- move_mode: Movement mode
- move_spd_rate: Movement speed rate (0-100)
- coor_mode: Coordinate mode
- reference_joint: Reference joint
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.MotionCtrl_1(
- ctrl_mode, move_mode, move_spd_rate, coor_mode, reference_joint
- )
-
- if result:
- return (True, "Motion control 1 parameters set successfully")
- else:
- return (False, "Failed to set motion control 1 parameters")
-
- except Exception as e:
- logger.error(f"set_motion_ctrl_1 failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_motion_ctrl_2(
- self,
- limit_fun_en: int = 0x00,
- collis_detect_en: int = 0x00,
- friction_feed_en: int = 0x00,
- gravity_feed_en: int = 0x00,
- is_mit_mode: int = 0x00,
- ) -> tuple[bool, str]:
- """
- Set motion control parameters (MotionCtrl_2).
-
- Args:
- limit_fun_en: Limit function enable (0x00 = disabled, 0x01 = enabled)
- collis_detect_en: Collision detection enable
- friction_feed_en: Friction compensation enable
- gravity_feed_en: Gravity compensation enable
- is_mit_mode: MIT mode enable (0x00 = disabled, 0x01 = enabled)
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.MotionCtrl_2(
- limit_fun_en,
- collis_detect_en,
- friction_feed_en,
- gravity_feed_en,
- is_mit_mode,
- )
-
- if result:
- return (True, "Motion control 2 parameters set successfully")
- else:
- return (False, "Failed to set motion control 2 parameters")
-
- except Exception as e:
- logger.error(f"set_motion_ctrl_2 failed: {e}")
- return (False, str(e))
-
- @rpc
- def set_mode_ctrl(
- self,
- drag_teach_en: int = 0x00,
- teach_record_en: int = 0x00,
- ) -> tuple[bool, str]:
- """
- Set mode control (drag teaching, recording, etc.).
-
- Args:
- drag_teach_en: Drag teaching enable (0x00 = disabled, 0x01 = enabled)
- teach_record_en: Teaching record enable
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.ModeCtrl(drag_teach_en, teach_record_en)
-
- if result:
- mode_str = []
- if drag_teach_en == 0x01:
- mode_str.append("drag teaching")
- if teach_record_en == 0x01:
- mode_str.append("recording")
-
- if mode_str:
- return (True, f"Mode control set: {', '.join(mode_str)} enabled")
- else:
- return (True, "Mode control set: all modes disabled")
- else:
- return (False, "Failed to set mode control")
-
- except Exception as e:
- logger.error(f"set_mode_ctrl failed: {e}")
- return (False, str(e))
-
- @rpc
- def configure_master_slave(
- self,
- linkage_config: int,
- feedback_offset: int,
- ctrl_offset: int,
- linkage_offset: int,
- ) -> tuple[bool, str]:
- """
- Configure master/slave linkage.
-
- Args:
- linkage_config: Linkage configuration
- feedback_offset: Feedback offset
- ctrl_offset: Control offset
- linkage_offset: Linkage offset
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.MasterSlaveConfig(
- linkage_config, feedback_offset, ctrl_offset, linkage_offset
- )
-
- if result:
- return (True, "Master/slave configuration set successfully")
- else:
- return (False, "Failed to set master/slave configuration")
-
- except Exception as e:
- logger.error(f"configure_master_slave failed: {e}")
- return (False, str(e))
-
- @rpc
- def search_firmware_version(self) -> tuple[bool, str]:
- """
- Search for firmware version.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.SearchPiperFirmwareVersion()
-
- if result:
- return (True, "Firmware version search initiated")
- else:
- return (False, "Failed to search firmware version")
-
- except Exception as e:
- logger.error(f"search_firmware_version failed: {e}")
- return (False, str(e))
-
- @rpc
- def piper_init(self) -> tuple[bool, str]:
- """
- Initialize Piper arm.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.PiperInit()
-
- if result:
- logger.info("Piper initialized")
- return (True, "Piper initialized successfully")
- else:
- logger.warning("Failed to initialize Piper")
- return (False, "Failed to initialize Piper")
-
- except Exception as e:
- logger.error(f"piper_init failed: {e}")
- return (False, str(e))
-
- @rpc
- def enable_piper(self) -> tuple[bool, str]:
- """
- Enable Piper (convenience method).
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.EnablePiper()
-
- if result:
- logger.info("Piper enabled")
- return (True, "Piper enabled")
- else:
- logger.warning("Failed to enable Piper")
- return (False, "Failed to enable Piper")
-
- except Exception as e:
- logger.error(f"enable_piper failed: {e}")
- return (False, str(e))
-
- @rpc
- def disable_piper(self) -> tuple[bool, str]:
- """
- Disable Piper (convenience method).
-
- Returns:
- Tuple of (success, message)
- """
- try:
- result = self.piper.DisablePiper()
-
- if result:
- logger.info("Piper disabled")
- return (True, "Piper disabled")
- else:
- logger.warning("Failed to disable Piper")
- return (False, "Failed to disable Piper")
-
- except Exception as e:
- logger.error(f"disable_piper failed: {e}")
- return (False, str(e))
-
- # =========================================================================
- # Velocity Control Mode
- # =========================================================================
-
- @rpc
- def enable_velocity_control_mode(self) -> tuple[bool, str]:
- """
- Enable velocity control mode (integration-based).
-
- This switches the control loop to use velocity integration:
- - Velocity commands are integrated: position_target += velocity * dt
- - Integrated positions are sent to JointCtrl (standard position control)
- - Provides smooth velocity control interface while using proven position API
-
- Returns:
- Tuple of (success, message)
- """
- try:
- # Set config flag to enable velocity control
- # The control loop will integrate velocities to positions
- self.config.velocity_control = True
-
- logger.info("Velocity control mode enabled (integration-based)")
- return (True, "Velocity control mode enabled")
-
- except Exception as e:
- logger.error(f"enable_velocity_control_mode failed: {e}")
- self.config.velocity_control = False # Revert on exception
- return (False, str(e))
-
- @rpc
- def disable_velocity_control_mode(self) -> tuple[bool, str]:
- """
- Disable velocity control mode and return to position control.
-
- Returns:
- Tuple of (success, message)
- """
- try:
- # Set config flag to disable velocity control
- # The control loop will switch back to standard position control mode
- self.config.velocity_control = False
-
- # Reset position target to allow re-initialization when re-enabled
- self._position_target_ = None
-
- logger.info("Position control mode enabled (velocity mode disabled)")
- return (True, "Position control mode enabled")
-
- except Exception as e:
- logger.error(f"disable_velocity_control_mode failed: {e}")
- self.config.velocity_control = True # Revert on exception
- return (False, str(e))
diff --git a/dimos/hardware/manipulators/piper/piper_blueprints.py b/dimos/hardware/manipulators/piper/piper_blueprints.py
deleted file mode 100644
index 1145616841..0000000000
--- a/dimos/hardware/manipulators/piper/piper_blueprints.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Blueprints for Piper manipulator control systems.
-
-This module provides declarative blueprints for configuring Piper servo control,
-following the same pattern used for xArm and other manipulators.
-
-Usage:
- # Run via CLI:
- dimos run piper-servo # Driver only
- dimos run piper-cartesian # Driver + Cartesian motion controller
- dimos run piper-trajectory # Driver + Joint trajectory controller
-
- # Or programmatically:
- from dimos.hardware.manipulators.piper.piper_blueprints import piper_servo
- coordinator = piper_servo.build()
- coordinator.loop()
-"""
-
-from typing import Any
-
-from dimos.core.blueprints import autoconnect
-from dimos.core.transport import LCMTransport
-from dimos.hardware.manipulators.piper.piper_driver import piper_driver as piper_driver_blueprint
-from dimos.manipulation.control import cartesian_motion_controller, joint_trajectory_controller
-from dimos.msgs.geometry_msgs import PoseStamped
-from dimos.msgs.sensor_msgs import (
- JointCommand,
- JointState,
- RobotState,
-)
-from dimos.msgs.trajectory_msgs import JointTrajectory
-
-
-# Create a blueprint wrapper for the component-based driver
-def piper_driver(**config: Any) -> Any:
- """Create a blueprint for PiperDriver.
-
- Args:
- **config: Configuration parameters passed to PiperDriver
- - can_port: CAN interface name (default: "can0")
- - has_gripper: Whether gripper is attached (default: True)
- - enable_on_start: Whether to enable servos on start (default: True)
- - control_rate: Control loop + joint feedback rate in Hz (default: 100)
- - monitor_rate: Robot state monitoring rate in Hz (default: 10)
-
- Returns:
- Blueprint configuration for PiperDriver
- """
- # Set defaults
- config.setdefault("can_port", "can0")
- config.setdefault("has_gripper", True)
- config.setdefault("enable_on_start", True)
- config.setdefault("control_rate", 100)
- config.setdefault("monitor_rate", 10)
-
- # Return the piper_driver blueprint with the config
- return piper_driver_blueprint(**config)
-
-
-# =============================================================================
-# Piper Servo Control Blueprint
-# =============================================================================
-# PiperDriver configured for servo control mode using component-based architecture.
-# Publishes joint states and robot state, listens for joint commands.
-# =============================================================================
-
-piper_servo = piper_driver(
- can_port="can0",
- has_gripper=True,
- enable_on_start=True,
- control_rate=100,
- monitor_rate=10,
-).transports(
- {
- # Joint state feedback (position, velocity, effort)
- ("joint_state", JointState): LCMTransport("/piper/joint_states", JointState),
- # Robot state feedback (mode, state, errors)
- ("robot_state", RobotState): LCMTransport("/piper/robot_state", RobotState),
- # Position commands input
- ("joint_position_command", JointCommand): LCMTransport(
- "/piper/joint_position_command", JointCommand
- ),
- # Velocity commands input
- ("joint_velocity_command", JointCommand): LCMTransport(
- "/piper/joint_velocity_command", JointCommand
- ),
- }
-)
-
-# =============================================================================
-# Piper Cartesian Control Blueprint (Driver + Controller)
-# =============================================================================
-# Combines PiperDriver with CartesianMotionController for Cartesian space control.
-# The controller receives target_pose and converts to joint commands via IK.
-# =============================================================================
-
-piper_cartesian = autoconnect(
- piper_driver(
- can_port="can0",
- has_gripper=True,
- enable_on_start=True,
- control_rate=100,
- monitor_rate=10,
- ),
- cartesian_motion_controller(
- control_frequency=20.0,
- position_kp=5.0,
- position_ki=0.0,
- position_kd=0.1,
- max_linear_velocity=0.2,
- max_angular_velocity=1.0,
- ),
-).transports(
- {
- # Shared topics between driver and controller
- ("joint_state", JointState): LCMTransport("/piper/joint_states", JointState),
- ("robot_state", RobotState): LCMTransport("/piper/robot_state", RobotState),
- ("joint_position_command", JointCommand): LCMTransport(
- "/piper/joint_position_command", JointCommand
- ),
- # Controller-specific topics
- ("target_pose", PoseStamped): LCMTransport("/target_pose", PoseStamped),
- ("current_pose", PoseStamped): LCMTransport("/piper/current_pose", PoseStamped),
- }
-)
-
-# =============================================================================
-# Piper Trajectory Control Blueprint (Driver + Trajectory Controller)
-# =============================================================================
-# Combines PiperDriver with JointTrajectoryController for trajectory execution.
-# The controller receives JointTrajectory messages and executes them at 100Hz.
-# =============================================================================
-
-piper_trajectory = autoconnect(
- piper_driver(
- can_port="can0",
- has_gripper=True,
- enable_on_start=True,
- control_rate=100,
- monitor_rate=10,
- ),
- joint_trajectory_controller(
- control_frequency=100.0,
- ),
-).transports(
- {
- # Shared topics between driver and controller
- ("joint_state", JointState): LCMTransport("/piper/joint_states", JointState),
- ("robot_state", RobotState): LCMTransport("/piper/robot_state", RobotState),
- ("joint_position_command", JointCommand): LCMTransport(
- "/piper/joint_position_command", JointCommand
- ),
- # Trajectory input topic
- ("trajectory", JointTrajectory): LCMTransport("/trajectory", JointTrajectory),
- }
-)
-
-__all__ = ["piper_cartesian", "piper_servo", "piper_trajectory"]
diff --git a/dimos/hardware/manipulators/piper/piper_description.urdf b/dimos/hardware/manipulators/piper/piper_description.urdf
deleted file mode 100755
index c8a5a11ded..0000000000
--- a/dimos/hardware/manipulators/piper/piper_description.urdf
+++ /dev/null
@@ -1,497 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/dimos/hardware/manipulators/piper/piper_driver.py b/dimos/hardware/manipulators/piper/piper_driver.py
deleted file mode 100644
index 5730a4394a..0000000000
--- a/dimos/hardware/manipulators/piper/piper_driver.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Piper driver using the generalized component-based architecture."""
-
-import logging
-import time
-from typing import Any
-
-from dimos.hardware.manipulators.base import (
- BaseManipulatorDriver,
- StandardMotionComponent,
- StandardServoComponent,
- StandardStatusComponent,
-)
-
-from .piper_wrapper import PiperSDKWrapper
-
-logger = logging.getLogger(__name__)
-
-
-class PiperDriver(BaseManipulatorDriver):
- """Piper driver using component-based architecture.
-
- This driver supports the Piper 6-DOF manipulator via CAN bus.
- All the complex logic is handled by the base class and standard components.
- This file just assembles the pieces.
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the Piper driver.
-
- Args:
- **kwargs: Arguments for Module initialization.
- Driver configuration can be passed via 'config' keyword arg:
- - can_port: CAN interface name (e.g., 'can0')
- - has_gripper: Whether gripper is attached
- - enable_on_start: Whether to enable servos on start
- """
- # Extract driver-specific config from kwargs
- config: dict[str, Any] = kwargs.pop("config", {})
-
- # Extract driver-specific params that might be passed directly
- driver_params = [
- "can_port",
- "has_gripper",
- "enable_on_start",
- "control_rate",
- "monitor_rate",
- ]
- for param in driver_params:
- if param in kwargs:
- config[param] = kwargs.pop(param)
-
- logger.info(f"Initializing PiperDriver with config: {config}")
-
- # Create SDK wrapper
- sdk = PiperSDKWrapper()
-
- # Create standard components
- components = [
- StandardMotionComponent(sdk),
- StandardServoComponent(sdk),
- StandardStatusComponent(sdk),
- ]
-
- # Optional: Add gripper component if configured
- # if config.get('has_gripper', False):
- # from dimos.hardware.manipulators.base.components import StandardGripperComponent
- # components.append(StandardGripperComponent(sdk))
-
- # Remove any kwargs that would conflict with explicit arguments
- kwargs.pop("sdk", None)
- kwargs.pop("components", None)
- kwargs.pop("name", None)
-
- # Initialize base driver with SDK and components
- super().__init__(
- sdk=sdk, components=components, config=config, name="PiperDriver", **kwargs
- )
-
- # Initialize position target for velocity integration
- self._position_target: list[float] | None = None
- self._last_velocity_time: float = 0.0
-
- # Enable on start if configured
- if config.get("enable_on_start", False):
- logger.info("Enabling Piper servos on start...")
- servo_component = self.get_component(StandardServoComponent)
- if servo_component:
- result = servo_component.enable_servo()
- if result["success"]:
- logger.info("Piper servos enabled successfully")
- else:
- logger.warning(f"Failed to enable servos: {result.get('error')}")
-
- logger.info("PiperDriver initialized successfully")
-
- def _process_command(self, command: Any) -> None:
- """Override to implement velocity control via position integration.
-
- Args:
- command: Command to process
- """
- # Handle velocity commands specially for Piper
- if command.type == "velocity":
- # Piper doesn't have native velocity control - integrate to position
- current_time = time.time()
-
- # Initialize position target from current state on first velocity command
- if self._position_target is None:
- positions = self.shared_state.joint_positions
- if positions:
- self._position_target = list(positions)
- logger.info(
- f"Velocity control: Initialized position target from current state: {self._position_target}"
- )
- else:
- logger.warning("Cannot start velocity control - no current position available")
- return
-
- # Calculate dt since last velocity command
- if self._last_velocity_time > 0:
- dt = current_time - self._last_velocity_time
- else:
- dt = 1.0 / self.control_rate # Use nominal period for first command
-
- self._last_velocity_time = current_time
-
- # Integrate velocity to position: pos += vel * dt
- velocities = command.data["velocities"]
- for i in range(min(len(velocities), len(self._position_target))):
- self._position_target[i] += velocities[i] * dt
-
- # Send integrated position command
- success = self.sdk.set_joint_positions(
- self._position_target,
- velocity=1.0, # Use max velocity for responsiveness
- acceleration=1.0,
- wait=False,
- )
-
- if success:
- self.shared_state.target_positions = self._position_target
- self.shared_state.target_velocities = velocities
-
- else:
- # Reset velocity integration when switching to position mode
- if command.type == "position":
- self._position_target = None
- self._last_velocity_time = 0.0
-
- # Use base implementation for other command types
- super()._process_command(command)
-
-
-# Blueprint configuration for the driver
-def get_blueprint() -> dict[str, Any]:
- """Get the blueprint configuration for the Piper driver.
-
- Returns:
- Dictionary with blueprint configuration
- """
- return {
- "name": "PiperDriver",
- "class": PiperDriver,
- "config": {
- "can_port": "can0", # Default CAN interface
- "has_gripper": True, # Piper usually has gripper
- "enable_on_start": True, # Enable servos on startup
- "control_rate": 100, # Hz - control loop + joint feedback
- "monitor_rate": 10, # Hz - robot state monitoring
- },
- "inputs": {
- "joint_position_command": "JointCommand",
- "joint_velocity_command": "JointCommand",
- },
- "outputs": {
- "joint_state": "JointState",
- "robot_state": "RobotState",
- },
- "rpc_methods": [
- # Motion control
- "move_joint",
- "move_joint_velocity",
- "move_joint_effort",
- "stop_motion",
- "get_joint_state",
- "get_joint_limits",
- "get_velocity_limits",
- "set_velocity_scale",
- "set_acceleration_scale",
- "move_cartesian",
- "get_cartesian_state",
- "execute_trajectory",
- "stop_trajectory",
- # Servo control
- "enable_servo",
- "disable_servo",
- "toggle_servo",
- "get_servo_state",
- "emergency_stop",
- "reset_emergency_stop",
- "set_control_mode",
- "get_control_mode",
- "clear_errors",
- "reset_fault",
- "home_robot",
- "brake_release",
- "brake_engage",
- # Status monitoring
- "get_robot_state",
- "get_system_info",
- "get_capabilities",
- "get_error_state",
- "get_health_metrics",
- "get_statistics",
- "check_connection",
- "get_force_torque",
- "zero_force_torque",
- "get_digital_inputs",
- "set_digital_outputs",
- "get_analog_inputs",
- "get_gripper_state",
- ],
- }
-
-
-# Expose blueprint for declarative composition (compatible with dimos framework)
-piper_driver = PiperDriver.blueprint
diff --git a/dimos/hardware/manipulators/piper/piper_wrapper.py b/dimos/hardware/manipulators/piper/piper_wrapper.py
deleted file mode 100644
index 7384f6c06e..0000000000
--- a/dimos/hardware/manipulators/piper/piper_wrapper.py
+++ /dev/null
@@ -1,671 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Piper SDK wrapper implementation."""
-
-import logging
-import time
-from typing import Any
-
-from ..base.sdk_interface import BaseManipulatorSDK, ManipulatorInfo
-
-# Unit conversion constants
-RAD_TO_PIPER = 57295.7795 # radians to Piper units (0.001 degrees)
-PIPER_TO_RAD = 1.0 / RAD_TO_PIPER # Piper units to radians
-
-
-class PiperSDKWrapper(BaseManipulatorSDK):
- """SDK wrapper for Piper manipulators.
-
- This wrapper translates Piper's native SDK (which uses radians but 1-indexed joints)
- to our standard interface (0-indexed).
- """
-
- def __init__(self) -> None:
- """Initialize the Piper SDK wrapper."""
- self.logger = logging.getLogger(self.__class__.__name__)
- self.native_sdk: Any = None
- self.dof = 6 # Piper is always 6-DOF
- self._connected = False
- self._enabled = False
-
- # ============= Connection Management =============
-
- def connect(self, config: dict[str, Any]) -> bool:
- """Connect to Piper via CAN bus.
-
- Args:
- config: Configuration with 'can_port' (e.g., 'can0')
-
- Returns:
- True if connection successful
- """
- try:
- from piper_sdk import C_PiperInterface_V2
-
- can_port = config.get("can_port", "can0")
- self.logger.info(f"Connecting to Piper via CAN port {can_port}...")
-
- # Create Piper SDK instance
- self.native_sdk = C_PiperInterface_V2(
- can_name=can_port,
- judge_flag=True, # Enable safety checks
- can_auto_init=True, # Let SDK handle CAN initialization
- dh_is_offset=False,
- )
-
- # Connect to CAN port
- self.native_sdk.ConnectPort(piper_init=True, start_thread=True)
-
- # Wait for initialization
- time.sleep(0.025)
-
- # Check connection by trying to get status
- status = self.native_sdk.GetArmStatus()
- if status is not None:
- self._connected = True
-
- # Get firmware version
- try:
- version = self.native_sdk.GetPiperFirmwareVersion()
- self.logger.info(f"Connected to Piper (firmware: {version})")
- except:
- self.logger.info("Connected to Piper")
-
- return True
- else:
- self.logger.error("Failed to connect to Piper - no status received")
- return False
-
- except ImportError:
- self.logger.error("Piper SDK not installed. Please install piper_sdk")
- return False
- except Exception as e:
- self.logger.error(f"Connection failed: {e}")
- return False
-
- def disconnect(self) -> None:
- """Disconnect from Piper."""
- if self.native_sdk:
- try:
- # Disable arm first
- if self._enabled:
- self.native_sdk.DisablePiper()
- self._enabled = False
-
- # Disconnect
- self.native_sdk.DisconnectPort()
- self._connected = False
- self.logger.info("Disconnected from Piper")
- except:
- pass
- finally:
- self.native_sdk = None
-
- def is_connected(self) -> bool:
- """Check if connected to Piper.
-
- Returns:
- True if connected
- """
- if not self._connected or not self.native_sdk:
- return False
-
- # Try to get status to verify connection
- try:
- status = self.native_sdk.GetArmStatus()
- return status is not None
- except:
- return False
-
- # ============= Joint State Query =============
-
- def get_joint_positions(self) -> list[float]:
- """Get current joint positions.
-
- Returns:
- Joint positions in RADIANS (0-indexed)
- """
- joint_msgs = self.native_sdk.GetArmJointMsgs()
- if not joint_msgs or not joint_msgs.joint_state:
- raise RuntimeError("Failed to get Piper joint positions")
-
- # Get joint positions from joint_state (values are in Piper units: 0.001 degrees)
- # Convert to radians using PIPER_TO_RAD conversion factor
- joint_state = joint_msgs.joint_state
- positions = [
- joint_state.joint_1 * PIPER_TO_RAD, # Convert Piper units to radians
- joint_state.joint_2 * PIPER_TO_RAD,
- joint_state.joint_3 * PIPER_TO_RAD,
- joint_state.joint_4 * PIPER_TO_RAD,
- joint_state.joint_5 * PIPER_TO_RAD,
- joint_state.joint_6 * PIPER_TO_RAD,
- ]
- return positions
-
- def get_joint_velocities(self) -> list[float]:
- """Get current joint velocities.
-
- Returns:
- Joint velocities in RAD/S (0-indexed)
- """
- # TODO: Get actual velocities from Piper SDK
- # For now return zeros as velocity feedback may not be available
- return [0.0] * self.dof
-
- def get_joint_efforts(self) -> list[float]:
- """Get current joint efforts/torques.
-
- Returns:
- Joint efforts in Nm (0-indexed)
- """
- # TODO: Get actual efforts/torques from Piper SDK if available
- # For now return zeros as effort feedback may not be available
- return [0.0] * self.dof
-
- # ============= Joint Motion Control =============
-
- def set_joint_positions(
- self,
- positions: list[float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> bool:
- """Move joints to target positions.
-
- Args:
- positions: Target positions in RADIANS (0-indexed)
- velocity: Max velocity fraction (0-1)
- acceleration: Max acceleration fraction (0-1)
- wait: If True, block until motion completes
-
- Returns:
- True if command accepted
- """
- # Convert radians to Piper units (0.001 degrees)
- piper_joints = [round(rad * RAD_TO_PIPER) for rad in positions]
-
- # Optionally set motion control parameters based on velocity/acceleration
- if velocity < 1.0 or acceleration < 1.0:
- # Scale speed rate based on velocity parameter (0-100)
- speed_rate = int(velocity * 100)
- self.native_sdk.MotionCtrl_2(
- ctrl_mode=0x01, # CAN control mode
- move_mode=0x01, # Move mode
- move_spd_rate_ctrl=speed_rate, # Speed rate
- is_mit_mode=0x00, # Not MIT mode
- )
-
- # Send joint control command using JointCtrl with 6 individual parameters
- try:
- self.native_sdk.JointCtrl(
- piper_joints[0], # Joint 1
- piper_joints[1], # Joint 2
- piper_joints[2], # Joint 3
- piper_joints[3], # Joint 4
- piper_joints[4], # Joint 5
- piper_joints[5], # Joint 6
- )
- result = True
- except Exception as e:
- self.logger.error(f"Error setting joint positions: {e}")
- result = False
-
- # If wait requested, poll until motion completes
- if wait and result:
- start_time = time.time()
- timeout = 30.0 # 30 second timeout
-
- while time.time() - start_time < timeout:
- try:
- # Check if reached target (within tolerance)
- current = self.get_joint_positions()
- tolerance = 0.01 # radians
- if all(abs(current[i] - positions[i]) < tolerance for i in range(6)):
- break
- except:
- pass # Continue waiting
- time.sleep(0.01)
-
- return result
-
- def set_joint_velocities(self, velocities: list[float]) -> bool:
- """Set joint velocity targets.
-
- Note: Piper doesn't have native velocity control. The driver should
- implement velocity control via position integration if needed.
-
- Args:
- velocities: Target velocities in RAD/S (0-indexed)
-
- Returns:
- False - velocity control not supported at SDK level
- """
- # Piper doesn't have native velocity control
- # The driver layer should implement this via position integration
- self.logger.debug("Velocity control not supported at SDK level - use position integration")
- return False
-
- def set_joint_efforts(self, efforts: list[float]) -> bool:
- """Set joint effort/torque targets.
-
- Args:
- efforts: Target efforts in Nm (0-indexed)
-
- Returns:
- True if command accepted
- """
- # Check if torque control is supported
- if not hasattr(self.native_sdk, "SetJointTorque"):
- self.logger.warning("Torque control not available in this Piper version")
- return False
-
- # Convert 0-indexed to 1-indexed dict
- torque_dict = {i + 1: torque for i, torque in enumerate(efforts)}
-
- # Send torque command
- self.native_sdk.SetJointTorque(torque_dict)
- return True
-
- def stop_motion(self) -> bool:
- """Stop all ongoing motion.
-
- Returns:
- True if stop successful
- """
- # Piper emergency stop
- if hasattr(self.native_sdk, "EmergencyStop"):
- self.native_sdk.EmergencyStop()
- else:
- # Alternative: set zero velocities
- zero_vel = {i: 0.0 for i in range(1, 7)}
- if hasattr(self.native_sdk, "SetJointSpeed"):
- self.native_sdk.SetJointSpeed(zero_vel)
-
- return True
-
- # ============= Servo Control =============
-
- def enable_servos(self) -> bool:
- """Enable motor control.
-
- Returns:
- True if servos enabled
- """
- # Enable Piper
- attempts = 0
- max_attempts = 100
-
- while not self.native_sdk.EnablePiper() and attempts < max_attempts:
- time.sleep(0.01)
- attempts += 1
-
- if attempts < max_attempts:
- self._enabled = True
-
- # Set control mode
- self.native_sdk.MotionCtrl_2(
- ctrl_mode=0x01, # CAN control mode
- move_mode=0x01, # Move mode
- move_spd_rate_ctrl=30, # Speed rate
- is_mit_mode=0x00, # Not MIT mode
- )
-
- return True
-
- return False
-
- def disable_servos(self) -> bool:
- """Disable motor control.
-
- Returns:
- True if servos disabled
- """
- self.native_sdk.DisablePiper()
- self._enabled = False
- return True
-
- def are_servos_enabled(self) -> bool:
- """Check if servos are enabled.
-
- Returns:
- True if enabled
- """
- return self._enabled
-
- # ============= System State =============
-
- def get_robot_state(self) -> dict[str, Any]:
- """Get current robot state.
-
- Returns:
- State dictionary
- """
- status = self.native_sdk.GetArmStatus()
-
- if status and status.arm_status:
- # Map Piper states to standard states
- # Use the nested arm_status object
- arm_status = status.arm_status
-
- # Default state mapping
- state = 0 # idle
- mode = 0 # position mode
- error_code = 0
-
- # Check for error status
- if hasattr(arm_status, "err_code"):
- error_code = arm_status.err_code
- if error_code != 0:
- state = 2 # error state
-
- # Check motion status if available
- if hasattr(arm_status, "motion_status"):
- # Could check if moving
- pass
-
- return {
- "state": state,
- "mode": mode,
- "error_code": error_code,
- "warn_code": 0, # Piper doesn't have warn codes
- "is_moving": False, # Would need to track this
- "cmd_num": 0, # Piper doesn't expose command queue
- }
-
- return {
- "state": 2, # Error if can't get status
- "mode": 0,
- "error_code": 999,
- "warn_code": 0,
- "is_moving": False,
- "cmd_num": 0,
- }
-
- def get_error_code(self) -> int:
- """Get current error code.
-
- Returns:
- Error code (0 = no error)
- """
- status = self.native_sdk.GetArmStatus()
- if status and hasattr(status, "error_code"):
- return int(status.error_code)
- return 0
-
- def get_error_message(self) -> str:
- """Get human-readable error message.
-
- Returns:
- Error message string
- """
- error_code = self.get_error_code()
- if error_code == 0:
- return ""
-
- # Piper error codes (approximate)
- error_map = {
- 1: "Communication error",
- 2: "Motor error",
- 3: "Encoder error",
- 4: "Overtemperature",
- 5: "Overcurrent",
- 6: "Joint limit error",
- 7: "Emergency stop",
- 8: "Power error",
- }
-
- return error_map.get(error_code, f"Unknown error {error_code}")
-
- def clear_errors(self) -> bool:
- """Clear error states.
-
- Returns:
- True if errors cleared
- """
- if hasattr(self.native_sdk, "ClearError"):
- self.native_sdk.ClearError()
- return True
-
- # Alternative: disable and re-enable
- self.disable_servos()
- time.sleep(0.1)
- return self.enable_servos()
-
- def emergency_stop(self) -> bool:
- """Execute emergency stop.
-
- Returns:
- True if e-stop executed
- """
- if hasattr(self.native_sdk, "EmergencyStop"):
- self.native_sdk.EmergencyStop()
- return True
-
- # Alternative: disable servos
- return self.disable_servos()
-
- # ============= Information =============
-
- def get_info(self) -> ManipulatorInfo:
- """Get manipulator information.
-
- Returns:
- ManipulatorInfo object
- """
- firmware_version = None
- try:
- firmware_version = self.native_sdk.GetPiperFirmwareVersion()
- except:
- pass
-
- return ManipulatorInfo(
- vendor="Agilex",
- model="Piper",
- dof=self.dof,
- firmware_version=firmware_version,
- serial_number=None, # Piper doesn't expose serial number
- )
-
- def get_joint_limits(self) -> tuple[list[float], list[float]]:
- """Get joint position limits.
-
- Returns:
- Tuple of (lower_limits, upper_limits) in RADIANS
- """
- # Piper joint limits (approximate, in radians)
- lower_limits = [-3.14, -2.35, -2.35, -3.14, -2.35, -3.14]
- upper_limits = [3.14, 2.35, 2.35, 3.14, 2.35, 3.14]
-
- return (lower_limits, upper_limits)
-
- def get_velocity_limits(self) -> list[float]:
- """Get joint velocity limits.
-
- Returns:
- Maximum velocities in RAD/S
- """
- # Piper max velocities (approximate)
- max_vel = 3.14 # rad/s
- return [max_vel] * self.dof
-
- def get_acceleration_limits(self) -> list[float]:
- """Get joint acceleration limits.
-
- Returns:
- Maximum accelerations in RAD/S²
- """
- # Piper max accelerations (approximate)
- max_acc = 10.0 # rad/s²
- return [max_acc] * self.dof
-
- # ============= Optional Methods =============
-
- def get_cartesian_position(self) -> dict[str, float] | None:
- """Get current end-effector pose.
-
- Returns:
- Pose dict or None if not supported
- """
- if hasattr(self.native_sdk, "GetEndPose"):
- pose = self.native_sdk.GetEndPose()
- if pose:
- return {
- "x": pose.x,
- "y": pose.y,
- "z": pose.z,
- "roll": pose.roll,
- "pitch": pose.pitch,
- "yaw": pose.yaw,
- }
- return None
-
- def set_cartesian_position(
- self,
- pose: dict[str, float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> bool:
- """Move end-effector to target pose.
-
- Args:
- pose: Target pose dict
- velocity: Max velocity fraction (0-1)
- acceleration: Max acceleration fraction (0-1)
- wait: Block until complete
-
- Returns:
- True if command accepted
- """
- if not hasattr(self.native_sdk, "MoveL"):
- self.logger.warning("Cartesian control not available")
- return False
-
- # Create pose object for Piper
- target = {
- "x": pose["x"],
- "y": pose["y"],
- "z": pose["z"],
- "roll": pose["roll"],
- "pitch": pose["pitch"],
- "yaw": pose["yaw"],
- }
-
- # Send Cartesian command
- self.native_sdk.MoveL(target)
-
- # Wait if requested
- if wait:
- start_time = time.time()
- timeout = 30.0
-
- while time.time() - start_time < timeout:
- current = self.get_cartesian_position()
- if current:
- # Check if reached target (within tolerance)
- tol_pos = 0.005 # 5mm
- tol_rot = 0.05 # ~3 degrees
-
- if (
- abs(current["x"] - pose["x"]) < tol_pos
- and abs(current["y"] - pose["y"]) < tol_pos
- and abs(current["z"] - pose["z"]) < tol_pos
- and abs(current["roll"] - pose["roll"]) < tol_rot
- and abs(current["pitch"] - pose["pitch"]) < tol_rot
- and abs(current["yaw"] - pose["yaw"]) < tol_rot
- ):
- break
-
- time.sleep(0.01)
-
- return True
-
- def get_gripper_position(self) -> float | None:
- """Get gripper position.
-
- Returns:
- Position in meters or None
- """
- if hasattr(self.native_sdk, "GetGripperState"):
- state = self.native_sdk.GetGripperState()
- if state:
- # Piper gripper position is 0-100 (percentage)
- # Convert to meters (assume max opening 0.08m)
- return float(state / 100.0) * 0.08
- return None
-
- def set_gripper_position(self, position: float, force: float = 1.0) -> bool:
- """Set gripper position.
-
- Args:
- position: Target position in meters
- force: Force fraction (0-1)
-
- Returns:
- True if successful
- """
- if not hasattr(self.native_sdk, "GripperCtrl"):
- self.logger.warning("Gripper control not available")
- return False
-
- # Convert meters to percentage (0-100)
- # Assume max opening 0.08m
- percentage = int((position / 0.08) * 100)
- percentage = max(0, min(100, percentage))
-
- # Control gripper
- self.native_sdk.GripperCtrl(percentage)
- return True
-
- def set_control_mode(self, mode: str) -> bool:
- """Set control mode.
-
- Args:
- mode: 'position', 'velocity', 'torque', or 'impedance'
-
- Returns:
- True if successful
- """
- # Piper modes via MotionCtrl_2
- # ctrl_mode: 0x01=CAN control
- # move_mode: 0x01=position, 0x02=velocity?
-
- if not hasattr(self.native_sdk, "MotionCtrl_2"):
- return False
-
- move_mode = 0x01 # Default position
- if mode == "velocity":
- move_mode = 0x02
-
- self.native_sdk.MotionCtrl_2(
- ctrl_mode=0x01, move_mode=move_mode, move_spd_rate_ctrl=30, is_mit_mode=0x00
- )
-
- return True
-
- def get_control_mode(self) -> str | None:
- """Get current control mode.
-
- Returns:
- Mode string or None
- """
- status = self.native_sdk.GetArmStatus()
- if status and hasattr(status, "arm_mode"):
- # Map Piper modes
- mode_map = {0x01: "position", 0x02: "velocity"}
- return mode_map.get(status.arm_mode, "unknown")
-
- return "position" # Default assumption
diff --git a/dimos/hardware/manipulators/spec.py b/dimos/hardware/manipulators/spec.py
new file mode 100644
index 0000000000..585043421e
--- /dev/null
+++ b/dimos/hardware/manipulators/spec.py
@@ -0,0 +1,261 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Manipulator specifications: Protocol and shared types.
+
+This file defines:
+1. Shared enums and dataclasses used by all arms
+2. ManipulatorBackend Protocol that backends must implement
+
+Note: No ABC for drivers. Each arm implements its own driver
+with full control over threading and logic.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Protocol, runtime_checkable
+
+from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3
+
+# ============================================================================
+# SHARED TYPES
+# ============================================================================
+
+
+class DriverStatus(Enum):
+ """Status returned by driver operations."""
+
+ DISCONNECTED = "disconnected"
+ CONNECTED = "connected"
+ ENABLED = "enabled"
+ MOVING = "moving"
+ ERROR = "error"
+
+
+class ControlMode(Enum):
+ """Control modes for manipulator."""
+
+ POSITION = "position" # Planned position control (slower, smoother)
+ SERVO_POSITION = "servo_position" # High-freq joint position streaming (100Hz+)
+ VELOCITY = "velocity"
+ TORQUE = "torque"
+ CARTESIAN = "cartesian"
+ CARTESIAN_VELOCITY = "cartesian_velocity"
+ IMPEDANCE = "impedance"
+
+
+@dataclass
+class ManipulatorInfo:
+ """Information about the manipulator."""
+
+ vendor: str
+ model: str
+ dof: int
+ firmware_version: str | None = None
+ serial_number: str | None = None
+
+
+@dataclass
+class JointLimits:
+ """Joint position and velocity limits."""
+
+ position_lower: list[float] # radians
+ position_upper: list[float] # radians
+ velocity_max: list[float] # rad/s
+
+
+def default_base_transform() -> Transform:
+ """Default identity transform for arm mounting."""
+ return Transform(
+ translation=Vector3(0.0, 0.0, 0.0),
+ rotation=Quaternion(0.0, 0.0, 0.0, 1.0),
+ )
+
+
+# ============================================================================
+# BACKEND PROTOCOL
+# ============================================================================
+
+
+@runtime_checkable
+class ManipulatorBackend(Protocol):
+ """Protocol for hardware-specific IO.
+
+ Implement this per vendor SDK. All methods use SI units:
+ - Angles: radians
+ - Angular velocity: rad/s
+ - Torque: Nm
+ - Position: meters
+ - Force: Newtons
+ """
+
+ # --- Connection ---
+
+ def connect(self) -> bool:
+ """Connect to hardware. Returns True on success."""
+ ...
+
+ def disconnect(self) -> None:
+ """Disconnect from hardware."""
+ ...
+
+ def is_connected(self) -> bool:
+ """Check if connected."""
+ ...
+
+ # --- Info ---
+
+ def get_info(self) -> ManipulatorInfo:
+ """Get manipulator info (vendor, model, DOF)."""
+ ...
+
+ def get_dof(self) -> int:
+ """Get degrees of freedom."""
+ ...
+
+ def get_limits(self) -> JointLimits:
+ """Get joint limits."""
+ ...
+
+ # --- Control Mode ---
+
+ def set_control_mode(self, mode: ControlMode) -> bool:
+ """Set control mode (position, velocity, torque, cartesian, etc).
+
+ Args:
+ mode: Target control mode
+
+ Returns:
+ True if mode switch successful, False otherwise
+
+ Note: Some arms (like XArm) may accept commands in any mode,
+ while others (like Piper) require explicit mode switching.
+ """
+ ...
+
+ def get_control_mode(self) -> ControlMode:
+ """Get current control mode.
+
+ Returns:
+ Current control mode
+ """
+ ...
+
+ # --- State Reading ---
+
+ def read_joint_positions(self) -> list[float]:
+ """Read current joint positions (radians)."""
+ ...
+
+ def read_joint_velocities(self) -> list[float]:
+ """Read current joint velocities (rad/s)."""
+ ...
+
+ def read_joint_efforts(self) -> list[float]:
+ """Read current joint efforts (Nm)."""
+ ...
+
+ def read_state(self) -> dict[str, int]:
+ """Read robot state (mode, state code, etc)."""
+ ...
+
+ def read_error(self) -> tuple[int, str]:
+ """Read error code and message. (0, '') means no error."""
+ ...
+
+ # --- Motion Control (Joint Space) ---
+
+ def write_joint_positions(
+ self,
+ positions: list[float],
+ velocity: float = 1.0,
+ ) -> bool:
+        """Command joint positions (radians) at velocity fraction of max (0-1). Returns success."""
+ ...
+
+ def write_joint_velocities(self, velocities: list[float]) -> bool:
+ """Command joint velocities (rad/s). Returns success."""
+ ...
+
+ def write_stop(self) -> bool:
+ """Stop all motion immediately."""
+ ...
+
+ # --- Servo Control ---
+
+ def write_enable(self, enable: bool) -> bool:
+ """Enable or disable servos. Returns success."""
+ ...
+
+ def read_enabled(self) -> bool:
+ """Check if servos are enabled."""
+ ...
+
+ def write_clear_errors(self) -> bool:
+ """Clear error state. Returns success."""
+ ...
+
+ # --- Optional: Cartesian Control ---
+ # Return None/False if not supported
+
+ def read_cartesian_position(self) -> dict[str, float] | None:
+ """Read end-effector pose.
+
+ Returns:
+ Dict with keys: x, y, z (meters), roll, pitch, yaw (radians)
+ None if not supported
+ """
+ ...
+
+ def write_cartesian_position(
+ self,
+ pose: dict[str, float],
+ velocity: float = 1.0,
+ ) -> bool:
+ """Command end-effector pose.
+
+ Args:
+ pose: Dict with keys: x, y, z (meters), roll, pitch, yaw (radians)
+ velocity: Speed as fraction of max (0-1)
+
+ Returns:
+ True if command accepted, False if not supported
+ """
+ ...
+
+ # --- Optional: Gripper ---
+
+ def read_gripper_position(self) -> float | None:
+ """Read gripper position (meters). None if no gripper."""
+ ...
+
+ def write_gripper_position(self, position: float) -> bool:
+        """Command gripper position (meters). False if no gripper."""
+ ...
+
+ # --- Optional: Force/Torque Sensor ---
+
+ def read_force_torque(self) -> list[float] | None:
+ """Read F/T sensor [fx, fy, fz, tx, ty, tz]. None if no sensor."""
+ ...
+
+
+__all__ = [
+ "ControlMode",
+ "DriverStatus",
+ "JointLimits",
+ "ManipulatorBackend",
+ "ManipulatorInfo",
+ "default_base_transform",
+]
diff --git a/dimos/hardware/manipulators/test_integration_runner.py b/dimos/hardware/manipulators/test_integration_runner.py
deleted file mode 100644
index eab6a022da..0000000000
--- a/dimos/hardware/manipulators/test_integration_runner.py
+++ /dev/null
@@ -1,626 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Integration test runner for manipulator drivers.
-
-This is a standalone script (NOT a pytest test file) that tests the common
-BaseManipulatorDriver interface that all arms implement.
-Supports both mock mode (for CI/CD) and hardware mode (for real testing).
-
-NOTE: This file is intentionally NOT named test_*.py to avoid pytest auto-discovery.
-For pytest-based unit tests, see: dimos/hardware/manipulators/base/tests/test_driver_unit.py
-
-Usage:
- # Run with mock (CI/CD safe, default)
- python -m dimos.hardware.manipulators.integration_test_runner
-
- # Run specific arm with mock
- python -m dimos.hardware.manipulators.integration_test_runner --arm piper
-
- # Run with real hardware (xArm)
- python -m dimos.hardware.manipulators.integration_test_runner --hardware --ip 192.168.1.210
-
- # Run with real hardware (Piper)
- python -m dimos.hardware.manipulators.integration_test_runner --hardware --arm piper --can can0
-
- # Run specific test
- python -m dimos.hardware.manipulators.integration_test_runner --test connection
-
- # Skip motion tests (safer for hardware)
- python -m dimos.hardware.manipulators.integration_test_runner --hardware --skip-motion
-"""
-
-import argparse
-import math
-import sys
-import time
-
-from dimos.core.transport import LCMTransport
-from dimos.hardware.manipulators.base.sdk_interface import BaseManipulatorSDK, ManipulatorInfo
-from dimos.msgs.sensor_msgs import JointState, RobotState
-
-
-class MockSDK(BaseManipulatorSDK):
- """Mock SDK for testing without hardware. Works for any arm type."""
-
- def __init__(self, dof: int = 6, vendor: str = "Mock", model: str = "TestArm"):
- self._connected = True
- self._dof = dof
- self._vendor = vendor
- self._model = model
- self._positions = [0.0] * dof
- self._velocities = [0.0] * dof
- self._efforts = [0.0] * dof
- self._servos_enabled = False
- self._mode = 0
- self._state = 0
- self._error_code = 0
-
- def connect(self, config: dict) -> bool:
- self._connected = True
- return True
-
- def disconnect(self) -> None:
- self._connected = False
-
- def is_connected(self) -> bool:
- return self._connected
-
- def get_joint_positions(self) -> list[float]:
- return self._positions.copy()
-
- def get_joint_velocities(self) -> list[float]:
- return self._velocities.copy()
-
- def get_joint_efforts(self) -> list[float]:
- return self._efforts.copy()
-
- def set_joint_positions(
- self,
- positions: list[float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> bool:
- if not self._servos_enabled:
- return False
- self._positions = list(positions)
- return True
-
- def set_joint_velocities(self, velocities: list[float]) -> bool:
- if not self._servos_enabled:
- return False
- self._velocities = list(velocities)
- return True
-
- def set_joint_efforts(self, efforts: list[float]) -> bool:
- return False # Not supported in mock
-
- def stop_motion(self) -> bool:
- self._velocities = [0.0] * self._dof
- return True
-
- def enable_servos(self) -> bool:
- self._servos_enabled = True
- return True
-
- def disable_servos(self) -> bool:
- self._servos_enabled = False
- return True
-
- def are_servos_enabled(self) -> bool:
- return self._servos_enabled
-
- def get_robot_state(self) -> dict:
- return {
- "state": self._state,
- "mode": self._mode,
- "error_code": self._error_code,
- "is_moving": any(v != 0 for v in self._velocities),
- }
-
- def get_error_code(self) -> int:
- return self._error_code
-
- def get_error_message(self) -> str:
- return "" if self._error_code == 0 else f"Error {self._error_code}"
-
- def clear_errors(self) -> bool:
- self._error_code = 0
- return True
-
- def emergency_stop(self) -> bool:
- self._velocities = [0.0] * self._dof
- self._servos_enabled = False
- return True
-
- def get_info(self) -> ManipulatorInfo:
- return ManipulatorInfo(
- vendor=self._vendor,
- model=f"{self._model} (Mock)",
- dof=self._dof,
- firmware_version="mock-1.0.0",
- serial_number="MOCK-001",
- )
-
- def get_joint_limits(self) -> tuple[list[float], list[float]]:
- lower = [-2 * math.pi] * self._dof
- upper = [2 * math.pi] * self._dof
- return lower, upper
-
- def get_velocity_limits(self) -> list[float]:
- return [math.pi] * self._dof
-
- def get_acceleration_limits(self) -> list[float]:
- return [math.pi * 2] * self._dof
-
-
-# =============================================================================
-# Test Functions (work with any driver implementing BaseManipulatorDriver)
-# =============================================================================
-
-
-def check_connection(driver, hardware: bool) -> bool:
- """Test that driver connects to hardware/mock."""
- print("Testing connection...")
-
- if not driver.sdk.is_connected():
- print(" FAIL: SDK not connected")
- return False
-
- info = driver.sdk.get_info()
- print(f" Connected to: {info.vendor} {info.model}")
- print(f" DOF: {info.dof}")
- print(f" Firmware: {info.firmware_version}")
- print(f" Mode: {'HARDWARE' if hardware else 'MOCK'}")
- print(" PASS")
- return True
-
-
-def check_read_joint_state(driver, hardware: bool) -> bool:
- """Test reading joint state."""
- print("Testing read joint state...")
-
- result = driver.get_joint_state()
- if not result.get("success"):
- print(f" FAIL: {result.get('error')}")
- return False
-
- positions = result["positions"]
- velocities = result["velocities"]
- efforts = result["efforts"]
-
- print(f" Positions (deg): {[f'{math.degrees(p):.1f}' for p in positions]}")
- print(f" Velocities: {[f'{v:.3f}' for v in velocities]}")
- print(f" Efforts: {[f'{e:.2f}' for e in efforts]}")
-
- if len(positions) != driver.capabilities.dof:
- print(f" FAIL: Expected {driver.capabilities.dof} joints, got {len(positions)}")
- return False
-
- print(" PASS")
- return True
-
-
-def check_get_robot_state(driver, hardware: bool) -> bool:
- """Test getting robot state."""
- print("Testing robot state...")
-
- result = driver.get_robot_state()
- if not result.get("success"):
- print(f" FAIL: {result.get('error')}")
- return False
-
- print(f" State: {result.get('state')}")
- print(f" Mode: {result.get('mode')}")
- print(f" Error code: {result.get('error_code')}")
- print(f" Is moving: {result.get('is_moving')}")
- print(" PASS")
- return True
-
-
-def check_servo_enable_disable(driver, hardware: bool) -> bool:
- """Test enabling and disabling servos."""
- print("Testing servo enable/disable...")
-
- # Enable
- result = driver.enable_servo()
- if not result.get("success"):
- print(f" FAIL enable: {result.get('error')}")
- return False
- print(" Enabled servos")
-
- # Hardware needs more time for state to propagate
- time.sleep(1.0 if hardware else 0.01)
-
- # Check state with retry for hardware
- enabled = driver.sdk.are_servos_enabled()
- if not enabled and hardware:
- # Retry after additional delay
- time.sleep(0.5)
- enabled = driver.sdk.are_servos_enabled()
-
- if not enabled:
- print(" FAIL: Servos not enabled after enable_servo()")
- return False
- print(" Verified servos enabled")
-
- # # Disable
- # result = driver.disable_servo()
- # if not result.get("success"):
- # print(f" FAIL disable: {result.get('error')}")
- # return False
- # print(" Disabled servos")
-
- print(" PASS")
- return True
-
-
-def check_joint_limits(driver, hardware: bool) -> bool:
- """Test getting joint limits."""
- print("Testing joint limits...")
-
- result = driver.get_joint_limits()
- if not result.get("success"):
- print(f" FAIL: {result.get('error')}")
- return False
-
- lower = result["lower"]
- upper = result["upper"]
-
- print(f" Lower (deg): {[f'{math.degrees(l):.1f}' for l in lower]}")
- print(f" Upper (deg): {[f'{math.degrees(u):.1f}' for u in upper]}")
-
- if len(lower) != driver.capabilities.dof:
- print(" FAIL: Wrong number of limits")
- return False
-
- print(" PASS")
- return True
-
-
-def check_stop_motion(driver, hardware: bool) -> bool:
- """Test stop motion command."""
- print("Testing stop motion...")
-
- result = driver.stop_motion()
- # Note: stop_motion may return success=False if arm isn't moving,
- # which is expected behavior. We just verify no exception occurred.
- if result is None:
- print(" FAIL: stop_motion returned None")
- return False
-
- if result.get("error"):
- print(f" FAIL: {result.get('error')}")
- return False
-
- # success=False when not moving is OK, success=True is also OK
- print(f" stop_motion returned success={result.get('success')}")
- print(" PASS")
- return True
-
-
-def check_small_motion(driver, hardware: bool) -> bool:
- """Test a small joint motion (5 degrees on joint 1).
-
- WARNING: With --hardware, this MOVES the real robot!
- """
- print("Testing small motion (5 deg on J1)...")
- if hardware:
- print(" WARNING: Robot will move!")
-
- # Get current position
- result = driver.get_joint_state()
- if not result.get("success"):
- print(f" FAIL: Cannot read state: {result.get('error')}")
- return False
-
- current_pos = list(result["positions"])
- print(f" Current J1: {math.degrees(current_pos[0]):.2f} deg")
-
- driver.clear_errors()
- # print(driver.get_state())
-
- # Enable servos
- result = driver.enable_servo()
- print(result)
- if not result.get("success"):
- print(f" FAIL: Cannot enable servos: {result.get('error')}")
- return False
-
- time.sleep(0.5 if hardware else 0.01)
-
- # Move +5 degrees on joint 1
- target_pos = current_pos.copy()
- target_pos[0] += math.radians(5.0)
- print(f" Target J1: {math.degrees(target_pos[0]):.2f} deg")
-
- result = driver.move_joint(target_pos, velocity=0.3, wait=True)
- if not result.get("success"):
- print(f" FAIL: Motion failed: {result.get('error')}")
- return False
-
- time.sleep(1.0 if hardware else 0.01)
-
- # Verify position
- result = driver.get_joint_state()
- new_pos = result["positions"]
- error = abs(new_pos[0] - target_pos[0])
- print(
- f" Reached J1: {math.degrees(new_pos[0]):.2f} deg (error: {math.degrees(error):.3f} deg)"
- )
-
- if hardware and error > math.radians(1.0): # Allow 1 degree error for real hardware
- print(" FAIL: Position error too large")
- return False
-
- # Move back
- print(" Moving back to original position...")
- driver.move_joint(current_pos, velocity=0.3, wait=True)
- time.sleep(1.0 if hardware else 0.01)
-
- print(" PASS")
- return True
-
-
-# =============================================================================
-# Driver Factory
-# =============================================================================
-
-
-def create_driver(arm: str, hardware: bool, config: dict):
- """Create driver for the specified arm type.
-
- Args:
- arm: Arm type ('xarm', 'piper', etc.)
- hardware: If True, use real hardware; if False, use mock SDK
- config: Configuration dict (ip, dof, etc.)
-
- Returns:
- Driver instance
- """
- if arm == "xarm":
- from dimos.hardware.manipulators.xarm.xarm_driver import XArmDriver
-
- if hardware:
- return XArmDriver(config=config)
- else:
- # Create driver with mock SDK
- driver = XArmDriver.__new__(XArmDriver)
- # Manually initialize with mock
- from dimos.hardware.manipulators.base import (
- BaseManipulatorDriver,
- StandardMotionComponent,
- StandardServoComponent,
- StandardStatusComponent,
- )
-
- mock_sdk = MockSDK(dof=config.get("dof", 6), vendor="UFactory", model="xArm")
- components = [
- StandardMotionComponent(),
- StandardServoComponent(),
- StandardStatusComponent(),
- ]
- BaseManipulatorDriver.__init__(
- driver, sdk=mock_sdk, components=components, config=config, name="XArmDriver"
- )
- return driver
-
- elif arm == "piper":
- from dimos.hardware.manipulators.piper.piper_driver import PiperDriver
-
- if hardware:
- return PiperDriver(config=config)
- else:
- # Create driver with mock SDK
- driver = PiperDriver.__new__(PiperDriver)
- from dimos.hardware.manipulators.base import (
- BaseManipulatorDriver,
- StandardMotionComponent,
- StandardServoComponent,
- StandardStatusComponent,
- )
-
- mock_sdk = MockSDK(dof=6, vendor="Agilex", model="Piper")
- components = [
- StandardMotionComponent(),
- StandardServoComponent(),
- StandardStatusComponent(),
- ]
- BaseManipulatorDriver.__init__(
- driver, sdk=mock_sdk, components=components, config=config, name="PiperDriver"
- )
- return driver
-
- else:
- raise ValueError(f"Unknown arm type: {arm}. Supported: xarm, piper")
-
-
-# =============================================================================
-# Test Runner
-# =============================================================================
-
-
-def configure_transports(driver, arm: str):
- """Configure LCM transports for the driver (like production does).
-
- Args:
- driver: The driver instance
- arm: Arm type for topic naming
- """
- # Create LCM transports for state publishing
- joint_state_transport = LCMTransport(f"/test/{arm}/joint_state", JointState)
- robot_state_transport = LCMTransport(f"/test/{arm}/robot_state", RobotState)
-
- # Set transports on driver's Out streams
- if driver.joint_state:
- driver.joint_state._transport = joint_state_transport
- if driver.robot_state:
- driver.robot_state._transport = robot_state_transport
-
-
-def run_tests(
- arm: str,
- hardware: bool,
- config: dict,
- test_name: str | None = None,
- skip_motion: bool = False,
-):
- """Run integration tests."""
- mode = "HARDWARE" if hardware else "MOCK"
- print("=" * 60)
- print(f"Manipulator Driver Integration Tests ({mode})")
- print("=" * 60)
- print(f"Arm: {arm}")
- print(f"Config: {config}")
- print()
-
- # Create driver
- print("Creating driver...")
- try:
- driver = create_driver(arm, hardware, config)
- except Exception as e:
- print(f"FATAL: Failed to create driver: {e}")
- return False
-
- # Configure transports (like production does)
- print("Configuring transports...")
- configure_transports(driver, arm)
-
- # Start driver
- print("Starting driver...")
- try:
- driver.start()
- # Piper needs more initialization time before commands work
- wait_time = 3.0 if (hardware and arm == "piper") else (1.0 if hardware else 0.1)
- time.sleep(wait_time)
- except Exception as e:
- print(f"FATAL: Failed to start driver: {e}")
- return False
-
- # Define tests (stop_motion last since it leaves arm in stopped state)
- tests = [
- ("connection", check_connection),
- ("read_state", check_read_joint_state),
- ("robot_state", check_get_robot_state),
- ("joint_limits", check_joint_limits),
- # ("servo", check_servo_enable_disable),
- ]
-
- if not skip_motion:
- tests.append(("motion", check_small_motion))
-
- # Stop test always last (leaves arm in stopped state)
- tests.append(("stop", check_stop_motion))
-
- # Run tests
- results = {}
- print()
- print("-" * 60)
-
- for name, test_func in tests:
- if test_name and name != test_name:
- continue
-
- try:
- results[name] = test_func(driver, hardware)
- except Exception as e:
- print(f" EXCEPTION: {e}")
- import traceback
-
- traceback.print_exc()
- results[name] = False
-
- print()
-
- # Stop driver
- print("Stopping driver...")
- try:
- driver.stop()
- except Exception as e:
- print(f"Warning: Error stopping driver: {e}")
-
- # Summary
- print("-" * 60)
- print("SUMMARY")
- print("-" * 60)
- passed = sum(1 for r in results.values() if r)
- total = len(results)
-
- for name, result in results.items():
- status = "PASS" if result else "FAIL"
- print(f" {name}: {status}")
-
- print()
- print(f"Result: {passed}/{total} tests passed")
-
- return passed == total
-
-
-def main():
- parser = argparse.ArgumentParser(
- description="Generic manipulator driver integration tests",
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog="""
-Examples:
- # Mock mode (CI/CD safe, default)
- python -m dimos.hardware.manipulators.integration_test_runner
-
- # xArm hardware mode
- python -m dimos.hardware.manipulators.integration_test_runner --hardware --ip 192.168.1.210
-
- # Piper hardware mode
- python -m dimos.hardware.manipulators.integration_test_runner --hardware --arm piper --can can0
-
- # Skip motion tests
- python -m dimos.hardware.manipulators.integration_test_runner --hardware --skip-motion
-""",
- )
- parser.add_argument(
- "--arm", default="xarm", choices=["xarm", "piper"], help="Arm type to test (default: xarm)"
- )
- parser.add_argument(
- "--hardware", action="store_true", help="Use real hardware (default: mock mode)"
- )
- parser.add_argument(
- "--ip", default="192.168.1.210", help="IP address for xarm (default: 192.168.1.210)"
- )
- parser.add_argument("--can", default="can0", help="CAN interface for piper (default: can0)")
- parser.add_argument(
- "--dof", type=int, help="Degrees of freedom (auto-detected in hardware mode)"
- )
- parser.add_argument("--test", help="Run specific test only")
- parser.add_argument("--skip-motion", action="store_true", help="Skip motion tests")
- args = parser.parse_args()
-
- # Build config - DOF auto-detected from hardware if not specified
- config = {}
- if args.arm == "xarm" and args.ip:
- config["ip"] = args.ip
- if args.arm == "piper" and args.can:
- config["can_port"] = args.can
- if args.dof:
- config["dof"] = args.dof
- elif not args.hardware:
- # Mock mode needs explicit DOF
- config["dof"] = 6
-
- success = run_tests(args.arm, args.hardware, config, args.test, args.skip_motion)
- sys.exit(0 if success else 1)
-
-
-if __name__ == "__main__":
- main()
diff --git a/dimos/hardware/manipulators/xarm/README.md b/dimos/hardware/manipulators/xarm/README.md
deleted file mode 100644
index ff7a797cad..0000000000
--- a/dimos/hardware/manipulators/xarm/README.md
+++ /dev/null
@@ -1,149 +0,0 @@
-# xArm Driver for dimos
-
-Real-time driver for UFACTORY xArm5/6/7 manipulators integrated with the dimos framework.
-
-## Quick Start
-
-### 1. Specify Robot IP
-
-**On boot** (Important)
-```bash
-sudo ifconfig lo multicast
-sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo
-```
-
-**Option A: Command-line argument** (recommended)
-```bash
-python test_xarm_driver.py --ip 192.168.1.235
-python interactive_control.py --ip 192.168.1.235
-```
-
-**Option B: Environment variable**
-```bash
-export XARM_IP=192.168.1.235
-python test_xarm_driver.py
-```
-
-**Option C: Use default** (192.168.1.235)
-```bash
-python test_xarm_driver.py # Uses default
-```
-
-**Note:** Command-line `--ip` takes precedence over `XARM_IP` environment variable.
-
-### 2. Basic Usage
-
-```python
-from dimos import core
-from dimos.hardware.manipulators.xarm.xarm_driver import XArmDriver
-from dimos.msgs.sensor_msgs import JointState, JointCommand
-
-# Start dimos and deploy driver
-dimos = core.start(1)
-xarm = dimos.deploy(XArmDriver, ip_address="192.168.1.235", xarm_type="xarm6")
-
-# Configure LCM transports
-xarm.joint_state.transport = core.LCMTransport("/xarm/joint_states", JointState)
-xarm.joint_position_command.transport = core.LCMTransport("/xarm/joint_commands", JointCommand)
-
-# Start and enable servo mode
-xarm.start()
-xarm.enable_servo_mode()
-
-# Control via RPC
-xarm.set_joint_angles([0, 0, 0, 0, 0, 0], speed=50, mvacc=100, mvtime=0)
-
-# Cleanup
-xarm.stop()
-dimos.stop()
-```
-
-## Key Features
-
-- **100Hz control loop** for real-time position/velocity control
-- **LCM pub/sub** for distributed system integration
-- **RPC methods** for direct hardware control
-- **Position mode** (radians) and **velocity mode** (deg/s)
-- **Component-based API**: motion, kinematics, system, gripper control
-
-## Topics
-
-**Subscribed:**
-- `/xarm/joint_position_command` - JointCommand (positions in radians)
-- `/xarm/joint_velocity_command` - JointCommand (velocities in deg/s)
-
-**Published:**
-- `/xarm/joint_states` - JointState (100Hz)
-- `/xarm/robot_state` - RobotState (10Hz)
-- `/xarm/ft_ext`, `/xarm/ft_raw` - WrenchStamped (force/torque)
-
-## Common RPC Methods
-
-```python
-# System control
-xarm.enable_servo_mode() # Enable position control (mode 1)
-xarm.enable_velocity_control_mode() # Enable velocity control (mode 4)
-xarm.motion_enable(True) # Enable motors
-xarm.clean_error() # Clear errors
-
-# Motion control
-xarm.set_joint_angles([...], speed=50, mvacc=100, mvtime=0)
-xarm.set_servo_angle(joint_id=5, angle=0.5, speed=50)
-
-# State queries
-state = xarm.get_joint_state()
-position = xarm.get_position()
-```
-
-## Configuration
-
-Key parameters for `XArmDriver`:
-- `ip_address`: Robot IP (default: "192.168.1.235")
-- `xarm_type`: Robot model - "xarm5", "xarm6", or "xarm7" (default: "xarm6")
-- `control_frequency`: Control loop rate in Hz (default: 100.0)
-- `is_radian`: Use radians vs degrees (default: True)
-- `enable_on_start`: Auto-enable servo mode (default: True)
-- `velocity_control`: Use velocity vs position mode (default: False)
-
-## Testing
-
-### With Mock Hardware (No Physical Robot)
-
-```bash
-# Unit tests with mocked xArm hardware
-python tests/test_xarm_rt_driver.py
-```
-
-### With Real Hardware
-
-**⚠️ Note:** Interactive control and hardware tests require a physical xArm connected to the network. Interactive control, and sample_trajectory_generator are part of test suite, and will be deprecated.
-
-**Using Alfred Embodiment:**
-
-To test with real hardware using the current Alfred embodiment:
-
-1. **Turn on the Flowbase** (xArm controller)
-2. **SSH into dimensional-cpu-2:**
- ```
-3. **Verify PC is connected to the controller:**
- ```bash
- ping 192.168.1.235 # Should respond
- ```
-4. **Run the interactive control:**
- ```bash
- # Interactive control (recommended)
- venv/bin/python dimos/hardware/manipulators/xarm/interactive_control.py --ip 192.168.1.235
-
- # Run driver standalone
- venv/bin/python dimos/hardware/manipulators/xarm/test_xarm_driver.py --ip 192.168.1.235
-
- # Run automated test suite
- venv/bin/python dimos/hardware/manipulators/xarm/test_xarm_driver.py --ip 192.168.1.235 --run-tests
-
- # Specify xArm model type (if using xArm7)
- venv/bin/python dimos/hardware/manipulators/xarm/interactive_control.py --ip 192.168.1.235 --type xarm7
- ```
-
-## License
-
-Copyright 2025 Dimensional Inc. - Apache License 2.0
diff --git a/dimos/hardware/manipulators/xarm/__init__.py b/dimos/hardware/manipulators/xarm/__init__.py
index ef0c6763c1..343ebc4e0e 100644
--- a/dimos/hardware/manipulators/xarm/__init__.py
+++ b/dimos/hardware/manipulators/xarm/__init__.py
@@ -12,18 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-xArm Manipulator Driver Module
+"""XArm manipulator hardware backend.
-Real-time driver and components for xArm5/6/7 manipulators.
+Usage:
+ >>> from dimos.hardware.manipulators.xarm import XArmBackend
+ >>> backend = XArmBackend(ip="192.168.1.185", dof=6)
+ >>> backend.connect()
+ >>> positions = backend.read_joint_positions()
"""
-from dimos.hardware.manipulators.xarm.spec import ArmDriverSpec
-from dimos.hardware.manipulators.xarm.xarm_driver import XArmDriver
-from dimos.hardware.manipulators.xarm.xarm_wrapper import XArmSDKWrapper
+from dimos.hardware.manipulators.xarm.backend import XArmBackend
-__all__ = [
- "ArmDriverSpec",
- "XArmDriver",
- "XArmSDKWrapper",
-]
+__all__ = ["XArmBackend"]
diff --git a/dimos/hardware/manipulators/xarm/backend.py b/dimos/hardware/manipulators/xarm/backend.py
new file mode 100644
index 0000000000..9adcdca24f
--- /dev/null
+++ b/dimos/hardware/manipulators/xarm/backend.py
@@ -0,0 +1,392 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""XArm backend - implements ManipulatorBackend protocol.
+
+Handles all XArm SDK communication and unit conversion.
+"""
+
+import math
+
+from xarm.wrapper import XArmAPI
+
+from dimos.hardware.manipulators.spec import (
+ ControlMode,
+ JointLimits,
+ ManipulatorBackend,
+ ManipulatorInfo,
+)
+
# XArm controller mode codes, passed to XArmAPI.set_mode() (see
# set_control_mode() below for the ControlMode -> mode-code mapping).
# NOTE(review): per the UFACTORY SDK, mode 1 is the generic "servo motion"
# mode used for high-frequency streaming of both joint and cartesian
# targets; the _SERVO_CARTESIAN name is narrower than what the mode does --
# confirm against the SDK before renaming.
_XARM_MODE_POSITION = 0  # planned point-to-point position control
_XARM_MODE_SERVO_CARTESIAN = 1  # servo motion mode (high-frequency streaming)
_XARM_MODE_JOINT_VELOCITY = 4  # joint velocity control (vc_set_joint_velocity)
_XARM_MODE_CARTESIAN_VELOCITY = 5  # cartesian velocity control
_XARM_MODE_JOINT_TORQUE = 6  # joint torque control
+
+
class XArmBackend(ManipulatorBackend):
    """XArm-specific backend implementing the ManipulatorBackend protocol.

    Wraps the UFACTORY ``XArmAPI`` SDK and performs all unit conversion so
    that callers work exclusively in SI units:

        - Angles:     radians (ours) <-> degrees (XArm)
        - Positions:  meters (ours)  <-> millimeters (XArm)
        - Velocities: rad/s (ours)   <-> deg/s (XArm)

    Implements the ManipulatorBackend protocol via duck typing; no
    inheritance beyond the protocol base is required — just matching
    method signatures.

    TODO: Consider creating XArmPose/XArmVelocity types to encapsulate
    unit conversions instead of helper methods. See ManipulatorPose discussion.
    """

    # =========================================================================
    # Unit Conversions (SI <-> XArm units)
    # =========================================================================

    @staticmethod
    def _m_to_mm(m: float) -> float:
        """Meters -> millimeters."""
        return m * 1000.0

    @staticmethod
    def _mm_to_m(mm: float) -> float:
        """Millimeters -> meters."""
        return mm / 1000.0

    @staticmethod
    def _rad_to_deg(rad: float) -> float:
        """Radians -> degrees."""
        return math.degrees(rad)

    @staticmethod
    def _deg_to_rad(deg: float) -> float:
        """Degrees -> radians."""
        return math.radians(deg)

    @staticmethod
    def _velocity_to_speed_mm(velocity: float) -> float:
        """Convert a 0-1 velocity fraction to mm/s (max ~500 mm/s).

        The fraction is clamped to [0, 1] so an out-of-range caller value
        cannot command an excessive cartesian speed.
        """
        return max(0.0, min(1.0, velocity)) * 500.0

    def __init__(self, ip: str, dof: int = 6) -> None:
        """Store connection parameters; no I/O happens until connect().

        Args:
            ip: IP address of the XArm controller.
            dof: Degrees of freedom (5, 6, or 7 depending on model).
        """
        self._ip = ip
        self._dof = dof
        self._arm: XArmAPI | None = None
        self._control_mode: ControlMode = ControlMode.POSITION

    # =========================================================================
    # Connection
    # =========================================================================

    def connect(self) -> bool:
        """Connect to the XArm over TCP/IP and prepare it for servo control.

        Returns:
            True on success; False on any failure (error printed to stdout).
        """
        try:
            self._arm = XArmAPI(self._ip)
            self._arm.connect()

            if not self._arm.connected:
                print(f"ERROR: XArm at {self._ip} not reachable (connected=False)")
                return False

            # The UFACTORY SDK requires motion_enable() before the arm
            # accepts mode/state changes or motion commands.
            self._arm.motion_enable(enable=True)

            # Mode 1 is the SDK's servo motion mode, used for
            # high-frequency streaming control (set_servo_angle_j).
            self._arm.set_mode(_XARM_MODE_SERVO_CARTESIAN)
            self._arm.set_state(0)  # state 0 = ready/sport state
            self._control_mode = ControlMode.SERVO_POSITION

            return True
        except Exception as e:
            print(f"ERROR: Failed to connect to XArm at {self._ip}: {e}")
            return False

    def disconnect(self) -> None:
        """Disconnect from the XArm and drop the SDK handle."""
        if self._arm:
            self._arm.disconnect()
            self._arm = None

    def is_connected(self) -> bool:
        """Return True if an SDK handle exists and reports connected."""
        return self._arm is not None and self._arm.connected

    # =========================================================================
    # Info
    # =========================================================================

    def get_info(self) -> ManipulatorInfo:
        """Return static vendor/model/DOF information for this arm."""
        return ManipulatorInfo(
            vendor="UFACTORY",
            model=f"xArm{self._dof}",
            dof=self._dof,
        )

    def get_dof(self) -> int:
        """Return the configured degrees of freedom."""
        return self._dof

    def get_limits(self) -> JointLimits:
        """Return placeholder joint limits in radians.

        NOTE(review): these are permissive placeholders (+/-2*pi position,
        ~180 deg/s velocity), not the true per-joint limits, which vary by
        joint and model -- confirm against the xArm user manual before
        relying on them for safety checks.
        """
        limit = 2 * math.pi
        return JointLimits(
            position_lower=[-limit] * self._dof,
            position_upper=[limit] * self._dof,
            velocity_max=[math.pi] * self._dof,  # ~180 deg/s
        )

    # =========================================================================
    # Control Mode
    # =========================================================================

    def set_control_mode(self, mode: ControlMode) -> bool:
        """Switch the XArm controller mode to match a ControlMode.

        Note: XArm is flexible and often accepts commands without explicit
        mode switching, but some operations require the correct mode.

        Returns:
            True if the SDK accepted the mode change; False if not
            connected, the mode is unsupported, or the SDK returned an
            error code.
        """
        if not self._arm:
            return False

        mode_map = {
            ControlMode.POSITION: _XARM_MODE_POSITION,
            ControlMode.SERVO_POSITION: _XARM_MODE_SERVO_CARTESIAN,  # mode 1: high-freq servo
            ControlMode.VELOCITY: _XARM_MODE_JOINT_VELOCITY,
            ControlMode.TORQUE: _XARM_MODE_JOINT_TORQUE,
            ControlMode.CARTESIAN: _XARM_MODE_SERVO_CARTESIAN,
            ControlMode.CARTESIAN_VELOCITY: _XARM_MODE_CARTESIAN_VELOCITY,
        }

        xarm_mode = mode_map.get(mode)
        if xarm_mode is None:
            return False

        code = self._arm.set_mode(xarm_mode)
        if code == 0:
            self._arm.set_state(0)  # re-enter ready state after mode change
            self._control_mode = mode
            return True
        return False

    def get_control_mode(self) -> ControlMode:
        """Return the last control mode successfully set by this backend."""
        return self._control_mode

    # =========================================================================
    # State Reading
    # =========================================================================

    def read_joint_positions(self) -> list[float]:
        """Read joint positions, converted from degrees to radians.

        Raises:
            RuntimeError: if not connected, or the SDK reports a non-zero
                error code, or returns no angle data.
        """
        if not self._arm:
            raise RuntimeError("Not connected")

        # get_servo_angle() returns (code, angles); a non-zero code means
        # the angle data is not trustworthy even if a list came back.
        code, angles = self._arm.get_servo_angle()
        if code != 0 or not angles:
            raise RuntimeError(f"Failed to read joint positions (code={code})")
        return [math.radians(a) for a in angles[: self._dof]]

    def read_joint_velocities(self) -> list[float]:
        """Return zeros for joint velocities.

        Note: XArm doesn't provide real-time velocity feedback directly.
        For velocity estimation, use finite differences on positions in
        the driver.
        """
        return [0.0] * self._dof

    def read_joint_efforts(self) -> list[float]:
        """Read joint torques in Nm; zeros if unavailable or on error."""
        if not self._arm:
            return [0.0] * self._dof

        code, torques = self._arm.get_joints_torque()
        if code == 0 and torques:
            return list(torques[: self._dof])
        return [0.0] * self._dof

    def read_state(self) -> dict[str, int]:
        """Read the raw XArm state and mode codes."""
        if not self._arm:
            return {"state": 0, "mode": 0}

        return {
            "state": self._arm.state,
            "mode": self._arm.mode,
        }

    def read_error(self) -> tuple[int, str]:
        """Return (error_code, message); (0, "") when no error or offline."""
        if not self._arm:
            return 0, ""

        code = self._arm.error_code
        if code == 0:
            return 0, ""
        return code, f"XArm error {code}"

    # =========================================================================
    # Motion Control (Joint Space)
    # =========================================================================

    def write_joint_positions(
        self,
        positions: list[float],
        velocity: float = 1.0,
    ) -> bool:
        """Stream joint position targets in servo mode (radians -> degrees).

        Uses set_servo_angle_j() for high-frequency servo control; requires
        mode 1 (servo mode) to be active. Only the last instruction is
        executed, which suits real-time streaming (100Hz+).

        Args:
            positions: Target positions in radians.
            velocity: Speed as fraction of max (0-1) — not used in servo mode.

        Returns:
            True if the SDK accepted the command.
        """
        if not self._arm:
            return False

        angles = [math.degrees(p) for p in positions]
        code: int = self._arm.set_servo_angle_j(angles, speed=100, mvacc=500)
        return code == 0

    def write_joint_velocities(self, velocities: list[float]) -> bool:
        """Write joint velocities (rad/s -> deg/s).

        Note: requires joint velocity mode (mode 4) to be active.
        """
        if not self._arm:
            return False

        speeds = [math.degrees(v) for v in velocities]
        code: int = self._arm.vc_set_joint_velocity(speeds)
        return code == 0

    def write_stop(self) -> bool:
        """Emergency stop via the SDK's emergency_stop()."""
        if not self._arm:
            return False
        code: int = self._arm.emergency_stop()
        return code == 0

    # =========================================================================
    # Servo Control
    # =========================================================================

    def write_enable(self, enable: bool) -> bool:
        """Enable or disable the servo motors."""
        if not self._arm:
            return False
        code: int = self._arm.motion_enable(enable=enable)
        return code == 0

    def read_enabled(self) -> bool:
        """Return True when the arm reports state 0 (ready/enabled)."""
        if not self._arm:
            return False
        state: int = self._arm.state
        return state == 0

    def write_clear_errors(self) -> bool:
        """Clear the controller's error state via clean_error()."""
        if not self._arm:
            return False
        code: int = self._arm.clean_error()
        return code == 0

    # =========================================================================
    # Cartesian Control (Optional)
    # =========================================================================

    def read_cartesian_position(self) -> dict[str, float] | None:
        """Read the end-effector pose (mm -> meters, degrees -> radians).

        Returns:
            Dict with keys x/y/z (meters) and roll/pitch/yaw (radians),
            or None if not connected or the SDK reports an error.
        """
        if not self._arm:
            return None

        # get_position() returns (code, pose); reject pose data when the
        # SDK signals an error.
        code, pose = self._arm.get_position()
        if code == 0 and pose and len(pose) >= 6:
            return {
                "x": self._mm_to_m(pose[0]),
                "y": self._mm_to_m(pose[1]),
                "z": self._mm_to_m(pose[2]),
                "roll": self._deg_to_rad(pose[3]),
                "pitch": self._deg_to_rad(pose[4]),
                "yaw": self._deg_to_rad(pose[5]),
            }
        return None

    def write_cartesian_position(
        self,
        pose: dict[str, float],
        velocity: float = 1.0,
    ) -> bool:
        """Command an end-effector pose (meters -> mm, radians -> degrees).

        Args:
            pose: Target pose; missing keys default to 0.
                NOTE(review): a partial pose therefore commands the missing
                axes to 0, which may be an unintended large move — confirm
                callers always supply all six keys.
            velocity: Speed as fraction of max (0-1).
        """
        if not self._arm:
            return False

        code: int = self._arm.set_position(
            x=self._m_to_mm(pose.get("x", 0)),
            y=self._m_to_mm(pose.get("y", 0)),
            z=self._m_to_mm(pose.get("z", 0)),
            roll=self._rad_to_deg(pose.get("roll", 0)),
            pitch=self._rad_to_deg(pose.get("pitch", 0)),
            yaw=self._rad_to_deg(pose.get("yaw", 0)),
            speed=self._velocity_to_speed_mm(velocity),
            wait=False,
        )
        return code == 0

    # =========================================================================
    # Gripper (Optional)
    # =========================================================================

    def read_gripper_position(self) -> float | None:
        """Read the gripper position (mm -> meters); None if unavailable."""
        if not self._arm:
            return None

        result = self._arm.get_gripper_position()
        code: int = result[0]
        pos: float | None = result[1]
        if code == 0 and pos is not None:
            return self._mm_to_m(pos)
        return None

    def write_gripper_position(self, position: float) -> bool:
        """Command a gripper position (meters -> mm)."""
        if not self._arm:
            return False

        code: int = self._arm.set_gripper_position(self._m_to_mm(position))
        return code == 0

    # =========================================================================
    # Force/Torque Sensor (Optional)
    # =========================================================================

    def read_force_torque(self) -> list[float] | None:
        """Read F/T sensor data if available; None if absent or on error."""
        if not self._arm:
            return None

        code, ft = self._arm.get_ft_sensor_data()
        if code == 0 and ft:
            return list(ft)
        return None


__all__ = ["XArmBackend"]
diff --git a/dimos/hardware/manipulators/xarm/components/__init__.py b/dimos/hardware/manipulators/xarm/components/__init__.py
deleted file mode 100644
index 4592560cda..0000000000
--- a/dimos/hardware/manipulators/xarm/components/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Component classes for XArmDriver."""
-
-from .gripper_control import GripperControlComponent
-from .kinematics import KinematicsComponent
-from .motion_control import MotionControlComponent
-from .state_queries import StateQueryComponent
-from .system_control import SystemControlComponent
-
-__all__ = [
- "GripperControlComponent",
- "KinematicsComponent",
- "MotionControlComponent",
- "StateQueryComponent",
- "SystemControlComponent",
-]
diff --git a/dimos/hardware/manipulators/xarm/components/gripper_control.py b/dimos/hardware/manipulators/xarm/components/gripper_control.py
deleted file mode 100644
index 13b8347978..0000000000
--- a/dimos/hardware/manipulators/xarm/components/gripper_control.py
+++ /dev/null
@@ -1,372 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Gripper Control Component for XArmDriver.
-
-Provides RPC methods for controlling various grippers:
-- Standard xArm gripper
-- Bio gripper
-- Vacuum gripper
-- Robotiq gripper
-"""
-
-from typing import TYPE_CHECKING, Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-if TYPE_CHECKING:
- from xarm.wrapper import XArmAPI
-
-logger = setup_logger()
-
-
-class GripperControlComponent:
- """
- Component providing gripper control RPC methods for XArmDriver.
-
- This component assumes the parent class has:
- - self.arm: XArmAPI instance
- - self.config: XArmDriverConfig instance
- """
-
- # Type hints for attributes expected from parent class
- arm: "XArmAPI"
- config: Any # Config dict accessed as object (dict with attribute access)
-
- # =========================================================================
- # Standard xArm Gripper
- # =========================================================================
-
- @rpc
- def set_gripper_enable(self, enable: int) -> tuple[int, str]:
- """Enable/disable gripper."""
- try:
- code = self.arm.set_gripper_enable(enable)
- return (
- code,
- f"Gripper {'enabled' if enable else 'disabled'}"
- if code == 0
- else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_gripper_mode(self, mode: int) -> tuple[int, str]:
- """Set gripper mode (0=location mode, 1=speed mode, 2=current mode)."""
- try:
- code = self.arm.set_gripper_mode(mode)
- return (code, f"Gripper mode set to {mode}" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_gripper_speed(self, speed: float) -> tuple[int, str]:
- """Set gripper speed (r/min)."""
- try:
- code = self.arm.set_gripper_speed(speed)
- return (code, f"Gripper speed set to {speed}" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_gripper_position(
- self,
- position: float,
- wait: bool = False,
- speed: float | None = None,
- timeout: float | None = None,
- ) -> tuple[int, str]:
- """
- Set gripper position.
-
- Args:
- position: Target position (0-850)
- wait: Wait for completion
- speed: Optional speed override
- timeout: Optional timeout for wait
- """
- try:
- code = self.arm.set_gripper_position(position, wait=wait, speed=speed, timeout=timeout)
- return (
- code,
- f"Gripper position set to {position}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def get_gripper_position(self) -> tuple[int, float | None]:
- """Get current gripper position."""
- try:
- code, position = self.arm.get_gripper_position()
- return (code, position if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_gripper_err_code(self) -> tuple[int, int | None]:
- """Get gripper error code."""
- try:
- code, err = self.arm.get_gripper_err_code()
- return (code, err if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def clean_gripper_error(self) -> tuple[int, str]:
- """Clear gripper error."""
- try:
- code = self.arm.clean_gripper_error()
- return (code, "Gripper error cleared" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Bio Gripper
- # =========================================================================
-
- @rpc
- def set_bio_gripper_enable(self, enable: int, wait: bool = True) -> tuple[int, str]:
- """Enable/disable bio gripper."""
- try:
- code = self.arm.set_bio_gripper_enable(enable, wait=wait)
- return (
- code,
- f"Bio gripper {'enabled' if enable else 'disabled'}"
- if code == 0
- else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_bio_gripper_speed(self, speed: int) -> tuple[int, str]:
- """Set bio gripper speed (1-100)."""
- try:
- code = self.arm.set_bio_gripper_speed(speed)
- return (
- code,
- f"Bio gripper speed set to {speed}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def open_bio_gripper(
- self, speed: int = 0, wait: bool = True, timeout: float = 5
- ) -> tuple[int, str]:
- """Open bio gripper."""
- try:
- code = self.arm.open_bio_gripper(speed=speed, wait=wait, timeout=timeout)
- return (code, "Bio gripper opened" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def close_bio_gripper(
- self, speed: int = 0, wait: bool = True, timeout: float = 5
- ) -> tuple[int, str]:
- """Close bio gripper."""
- try:
- code = self.arm.close_bio_gripper(speed=speed, wait=wait, timeout=timeout)
- return (code, "Bio gripper closed" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def get_bio_gripper_status(self) -> tuple[int, int | None]:
- """Get bio gripper status."""
- try:
- code, status = self.arm.get_bio_gripper_status()
- return (code, status if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_bio_gripper_error(self) -> tuple[int, int | None]:
- """Get bio gripper error code."""
- try:
- code, error = self.arm.get_bio_gripper_error()
- return (code, error if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def clean_bio_gripper_error(self) -> tuple[int, str]:
- """Clear bio gripper error."""
- try:
- code = self.arm.clean_bio_gripper_error()
- return (code, "Bio gripper error cleared" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Vacuum Gripper
- # =========================================================================
-
- @rpc
- def set_vacuum_gripper(self, on: int) -> tuple[int, str]:
- """Turn vacuum gripper on/off (0=off, 1=on)."""
- try:
- code = self.arm.set_vacuum_gripper(on)
- return (
- code,
- f"Vacuum gripper {'on' if on else 'off'}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def get_vacuum_gripper(self) -> tuple[int, int | None]:
- """Get vacuum gripper state."""
- try:
- code, state = self.arm.get_vacuum_gripper()
- return (code, state if code == 0 else None)
- except Exception:
- return (-1, None)
-
- # =========================================================================
- # Robotiq Gripper
- # =========================================================================
-
- @rpc
- def robotiq_reset(self) -> tuple[int, str]:
- """Reset Robotiq gripper."""
- try:
- code = self.arm.robotiq_reset()
- return (code, "Robotiq gripper reset" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def robotiq_set_activate(self, wait: bool = True, timeout: float = 3) -> tuple[int, str]:
- """Activate Robotiq gripper."""
- try:
- code = self.arm.robotiq_set_activate(wait=wait, timeout=timeout)
- return (code, "Robotiq gripper activated" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def robotiq_set_position(
- self,
- position: int,
- speed: int = 0xFF,
- force: int = 0xFF,
- wait: bool = True,
- timeout: float = 5,
- ) -> tuple[int, str]:
- """
- Set Robotiq gripper position.
-
- Args:
- position: Target position (0-255, 0=open, 255=closed)
- speed: Gripper speed (0-255)
- force: Gripper force (0-255)
- wait: Wait for completion
- timeout: Timeout for wait
- """
- try:
- code = self.arm.robotiq_set_position(
- position, speed=speed, force=force, wait=wait, timeout=timeout
- )
- return (
- code,
- f"Robotiq position set to {position}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def robotiq_open(
- self, speed: int = 0xFF, force: int = 0xFF, wait: bool = True, timeout: float = 5
- ) -> tuple[int, str]:
- """Open Robotiq gripper."""
- try:
- code = self.arm.robotiq_open(speed=speed, force=force, wait=wait, timeout=timeout)
- return (code, "Robotiq gripper opened" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def robotiq_close(
- self, speed: int = 0xFF, force: int = 0xFF, wait: bool = True, timeout: float = 5
- ) -> tuple[int, str]:
- """Close Robotiq gripper."""
- try:
- code = self.arm.robotiq_close(speed=speed, force=force, wait=wait, timeout=timeout)
- return (code, "Robotiq gripper closed" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def robotiq_get_status(self) -> tuple[int, dict[str, Any] | None]:
- """Get Robotiq gripper status."""
- try:
- ret = self.arm.robotiq_get_status()
- if isinstance(ret, tuple) and len(ret) >= 2:
- code = ret[0]
- if code == 0:
- # Return status as dict if successful
- status = {
- "gOBJ": ret[1] if len(ret) > 1 else None, # Object detection status
- "gSTA": ret[2] if len(ret) > 2 else None, # Gripper status
- "gGTO": ret[3] if len(ret) > 3 else None, # Go to requested position
- "gACT": ret[4] if len(ret) > 4 else None, # Activation status
- "kFLT": ret[5] if len(ret) > 5 else None, # Fault status
- "gFLT": ret[6] if len(ret) > 6 else None, # Fault status
- "gPR": ret[7] if len(ret) > 7 else None, # Requested position echo
- "gPO": ret[8] if len(ret) > 8 else None, # Actual position
- "gCU": ret[9] if len(ret) > 9 else None, # Current
- }
- return (code, status)
- return (code, None)
- return (-1, None)
- except Exception as e:
- logger.error(f"robotiq_get_status failed: {e}")
- return (-1, None)
-
- # =========================================================================
- # Lite6 Gripper
- # =========================================================================
-
- @rpc
- def open_lite6_gripper(self) -> tuple[int, str]:
- """Open Lite6 gripper."""
- try:
- code = self.arm.open_lite6_gripper()
- return (code, "Lite6 gripper opened" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def close_lite6_gripper(self) -> tuple[int, str]:
- """Close Lite6 gripper."""
- try:
- code = self.arm.close_lite6_gripper()
- return (code, "Lite6 gripper closed" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def stop_lite6_gripper(self) -> tuple[int, str]:
- """Stop Lite6 gripper."""
- try:
- code = self.arm.stop_lite6_gripper()
- return (code, "Lite6 gripper stopped" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
diff --git a/dimos/hardware/manipulators/xarm/components/kinematics.py b/dimos/hardware/manipulators/xarm/components/kinematics.py
deleted file mode 100644
index c29007a426..0000000000
--- a/dimos/hardware/manipulators/xarm/components/kinematics.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Kinematics Component for XArmDriver.
-
-Provides RPC methods for kinematic calculations including:
-- Forward kinematics
-- Inverse kinematics
-"""
-
-from typing import TYPE_CHECKING, Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-if TYPE_CHECKING:
- from xarm.wrapper import XArmAPI
-
-logger = setup_logger()
-
-
-class KinematicsComponent:
- """
- Component providing kinematics RPC methods for XArmDriver.
-
- This component assumes the parent class has:
- - self.arm: XArmAPI instance
- - self.config: XArmDriverConfig instance
- """
-
- # Type hints for attributes expected from parent class
- arm: "XArmAPI"
- config: Any # Config dict accessed as object (dict with attribute access)
-
- @rpc
- def get_inverse_kinematics(self, pose: list[float]) -> tuple[int, list[float] | None]:
- """
- Compute inverse kinematics.
-
- Args:
- pose: [x, y, z, roll, pitch, yaw]
-
- Returns:
- Tuple of (code, joint_angles)
- """
- try:
- code, angles = self.arm.get_inverse_kinematics(
- pose, input_is_radian=self.config.is_radian, return_is_radian=self.config.is_radian
- )
- return (code, list(angles) if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_forward_kinematics(self, angles: list[float]) -> tuple[int, list[float] | None]:
- """
- Compute forward kinematics.
-
- Args:
- angles: Joint angles
-
- Returns:
- Tuple of (code, pose)
- """
- try:
- code, pose = self.arm.get_forward_kinematics(
- angles,
- input_is_radian=self.config.is_radian,
- return_is_radian=self.config.is_radian,
- )
- return (code, list(pose) if code == 0 else None)
- except Exception:
- return (-1, None)
diff --git a/dimos/hardware/manipulators/xarm/components/motion_control.py b/dimos/hardware/manipulators/xarm/components/motion_control.py
deleted file mode 100644
index 64aaa861e0..0000000000
--- a/dimos/hardware/manipulators/xarm/components/motion_control.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Motion Control Component for XArmDriver.
-
-Provides RPC methods for motion control operations including:
-- Joint position control
-- Joint velocity control
-- Cartesian position control
-- Home positioning
-"""
-
-import math
-import threading
-from typing import TYPE_CHECKING, Any
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-if TYPE_CHECKING:
- from xarm.wrapper import XArmAPI
-
-logger = setup_logger()
-
-
-class MotionControlComponent:
- """
- Component providing motion control RPC methods for XArmDriver.
-
- This component assumes the parent class has:
- - self.arm: XArmAPI instance
- - self.config: XArmDriverConfig instance
- - self._joint_cmd_lock: threading.Lock
- - self._joint_cmd_: Optional[list[float]]
- """
-
- # Type hints for attributes expected from parent class
- arm: "XArmAPI"
- config: Any # Config dict accessed as object (dict with attribute access)
- _joint_cmd_lock: threading.Lock
- _joint_cmd_: list[float] | None
-
- @rpc
- def set_joint_angles(self, angles: list[float]) -> tuple[int, str]:
- """
- Set joint angles (RPC method).
-
- Args:
- angles: List of joint angles (in radians if is_radian=True)
-
- Returns:
- Tuple of (code, message)
- """
- try:
- code = self.arm.set_servo_angle_j(angles=angles, is_radian=self.config.is_radian)
- msg = "Success" if code == 0 else f"Error code: {code}"
- return (code, msg)
- except Exception as e:
- logger.error(f"set_joint_angles failed: {e}")
- return (-1, str(e))
-
- @rpc
- def set_joint_velocities(self, velocities: list[float]) -> tuple[int, str]:
- """
- Set joint velocities (RPC method).
- Note: Requires velocity control mode.
-
- Args:
- velocities: List of joint velocities (rad/s)
-
- Returns:
- Tuple of (code, message)
- """
- try:
- # For velocity control, you would use vc_set_joint_velocity
- # This requires mode 4 (joint velocity control)
- code = self.arm.vc_set_joint_velocity(
- speeds=velocities, is_radian=self.config.is_radian
- )
- msg = "Success" if code == 0 else f"Error code: {code}"
- return (code, msg)
- except Exception as e:
- logger.error(f"set_joint_velocities failed: {e}")
- return (-1, str(e))
-
- @rpc
- def set_position(self, position: list[float], wait: bool = False) -> tuple[int, str]:
- """
- Set TCP position [x, y, z, roll, pitch, yaw].
-
- Args:
- position: Target position
- wait: Wait for motion to complete
-
- Returns:
- Tuple of (code, message)
- """
- try:
- code = self.arm.set_position(*position, is_radian=self.config.is_radian, wait=wait)
- return (code, "Success" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def move_gohome(self, wait: bool = False) -> tuple[int, str]:
- """Move to home position."""
- try:
- code = self.arm.move_gohome(wait=wait, is_radian=self.config.is_radian)
- return (code, "Moving home" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_joint_command(self, positions: list[float]) -> tuple[int, str]:
- """
- Manually set the joint command (for testing).
- This updates the shared joint_cmd that the control loop reads.
-
- Args:
- positions: List of joint positions in radians
-
- Returns:
- Tuple of (code, message)
- """
- try:
- if len(positions) != self.config.num_joints:
- return (-1, f"Expected {self.config.num_joints} positions, got {len(positions)}")
-
- with self._joint_cmd_lock:
- self._joint_cmd_ = list(positions)
-
- logger.info(f"✓ Joint command set: {[f'{math.degrees(p):.2f}°' for p in positions]}")
- return (0, "Joint command updated")
- except Exception as e:
- return (-1, str(e))
diff --git a/dimos/hardware/manipulators/xarm/components/state_queries.py b/dimos/hardware/manipulators/xarm/components/state_queries.py
deleted file mode 100644
index 5615763cc4..0000000000
--- a/dimos/hardware/manipulators/xarm/components/state_queries.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-State Query Component for XArmDriver.
-
-Provides RPC methods for querying robot state including:
-- Joint state
-- Robot state
-- Cartesian position
-- Firmware version
-"""
-
-import threading
-from typing import TYPE_CHECKING, Any
-
-from dimos.core import rpc
-from dimos.msgs.sensor_msgs import JointState, RobotState
-from dimos.utils.logging_config import setup_logger
-
-if TYPE_CHECKING:
- from xarm.wrapper import XArmAPI
-
-logger = setup_logger()
-
-
-class StateQueryComponent:
- """
- Component providing state query RPC methods for XArmDriver.
-
- This component assumes the parent class has:
- - self.arm: XArmAPI instance
- - self.config: XArmDriverConfig instance
- - self._joint_state_lock: threading.Lock
- - self._joint_states_: Optional[JointState]
- - self._robot_state_: Optional[RobotState]
- """
-
- # Type hints for attributes expected from parent class
- arm: "XArmAPI"
- config: Any # Config dict accessed as object (dict with attribute access)
- _joint_state_lock: threading.Lock
- _joint_states_: JointState | None
- _robot_state_: RobotState | None
-
- @rpc
- def get_joint_state(self) -> JointState | None:
- """
- Get the current joint state (RPC method).
-
- Returns:
- Current JointState or None
- """
- with self._joint_state_lock:
- return self._joint_states_
-
- @rpc
- def get_robot_state(self) -> RobotState | None:
- """
- Get the current robot state (RPC method).
-
- Returns:
- Current RobotState or None
- """
- with self._joint_state_lock:
- return self._robot_state_
-
- @rpc
- def get_position(self) -> tuple[int, list[float] | None]:
- """
- Get TCP position [x, y, z, roll, pitch, yaw].
-
- Returns:
- Tuple of (code, position)
- """
- try:
- code, position = self.arm.get_position(is_radian=self.config.is_radian)
- return (code, list(position) if code == 0 else None)
- except Exception as e:
- logger.error(f"get_position failed: {e}")
- return (-1, None)
-
- @rpc
- def get_version(self) -> tuple[int, str | None]:
- """Get firmware version."""
- try:
- code, version = self.arm.get_version()
- return (code, version if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_servo_angle(self) -> tuple[int, list[float] | None]:
- """Get joint angles."""
- try:
- code, angles = self.arm.get_servo_angle(is_radian=self.config.is_radian)
- return (code, list(angles) if code == 0 else None)
- except Exception as e:
- logger.error(f"get_servo_angle failed: {e}")
- return (-1, None)
-
- @rpc
- def get_position_aa(self) -> tuple[int, list[float] | None]:
- """Get TCP position in axis-angle format."""
- try:
- code, position = self.arm.get_position_aa(is_radian=self.config.is_radian)
- return (code, list(position) if code == 0 else None)
- except Exception as e:
- logger.error(f"get_position_aa failed: {e}")
- return (-1, None)
-
- # =========================================================================
- # Robot State Queries
- # =========================================================================
-
- @rpc
- def get_state(self) -> tuple[int, int | None]:
- """Get robot state (0=ready, 3=pause, 4=stop)."""
- try:
- code, state = self.arm.get_state()
- return (code, state if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_cmdnum(self) -> tuple[int, int | None]:
- """Get command queue length."""
- try:
- code, cmdnum = self.arm.get_cmdnum()
- return (code, cmdnum if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_err_warn_code(self) -> tuple[int, list[int] | None]:
- """Get error and warning codes."""
- try:
- err_warn = [0, 0]
- code = self.arm.get_err_warn_code(err_warn)
- return (code, err_warn if code == 0 else None)
- except Exception:
- return (-1, None)
-
- # =========================================================================
- # Force/Torque Sensor Queries
- # =========================================================================
-
- @rpc
- def get_ft_sensor_data(self) -> tuple[int, list[float] | None]:
- """Get force/torque sensor data [fx, fy, fz, tx, ty, tz]."""
- try:
- code, ft_data = self.arm.get_ft_sensor_data()
- return (code, list(ft_data) if code == 0 else None)
- except Exception as e:
- logger.error(f"get_ft_sensor_data failed: {e}")
- return (-1, None)
-
- @rpc
- def get_ft_sensor_error(self) -> tuple[int, int | None]:
- """Get FT sensor error code."""
- try:
- code, error = self.arm.get_ft_sensor_error()
- return (code, error if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_ft_sensor_mode(self) -> tuple[int, int | None]:
- """Get FT sensor application mode."""
- try:
- code, mode = self.arm.get_ft_sensor_app_get()
- return (code, mode if code == 0 else None)
- except Exception:
- return (-1, None)
diff --git a/dimos/hardware/manipulators/xarm/components/system_control.py b/dimos/hardware/manipulators/xarm/components/system_control.py
deleted file mode 100644
index a04e9a94a0..0000000000
--- a/dimos/hardware/manipulators/xarm/components/system_control.py
+++ /dev/null
@@ -1,555 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-System Control Component for XArmDriver.
-
-Provides RPC methods for system-level control operations including:
-- Mode control (servo, velocity)
-- State management
-- Error handling
-- Emergency stop
-"""
-
-from typing import TYPE_CHECKING, Any, Protocol
-
-from dimos.core import rpc
-from dimos.utils.logging_config import setup_logger
-
-if TYPE_CHECKING:
- from xarm.wrapper import XArmAPI
-
- class XArmConfig(Protocol):
- """Protocol for XArm configuration."""
-
- is_radian: bool
- velocity_control: bool
-
-
-logger = setup_logger()
-
-
-class SystemControlComponent:
- """
- Component providing system control RPC methods for XArmDriver.
-
- This component assumes the parent class has:
- - self.arm: XArmAPI instance
- - self.config: XArmDriverConfig instance
- """
-
- # Type hints for attributes expected from parent class
- arm: "XArmAPI"
- config: Any # Should be XArmConfig but accessed as dict
-
- @rpc
- def enable_servo_mode(self) -> tuple[int, str]:
- """
- Enable servo mode (mode 1).
- Required for set_servo_angle_j to work.
-
- Returns:
- Tuple of (code, message)
- """
- try:
- code = self.arm.set_mode(1)
- if code == 0:
- logger.info("Servo mode enabled")
- return (code, "Servo mode enabled")
- else:
- logger.warning(f"Failed to enable servo mode: code={code}")
- return (code, f"Error code: {code}")
- except Exception as e:
- logger.error(f"enable_servo_mode failed: {e}")
- return (-1, str(e))
-
- @rpc
- def disable_servo_mode(self) -> tuple[int, str]:
- """
- Disable servo mode (set to position mode).
-
- Returns:
- Tuple of (code, message)
- """
- try:
- code = self.arm.set_mode(0)
- if code == 0:
- logger.info("Servo mode disabled (position mode)")
- return (code, "Position mode enabled")
- else:
- logger.warning(f"Failed to disable servo mode: code={code}")
- return (code, f"Error code: {code}")
- except Exception as e:
- logger.error(f"disable_servo_mode failed: {e}")
- return (-1, str(e))
-
- @rpc
- def enable_velocity_control_mode(self) -> tuple[int, str]:
- """
- Enable velocity control mode (mode 4).
- Required for vc_set_joint_velocity to work.
-
- Returns:
- Tuple of (code, message)
- """
- try:
- # IMPORTANT: Set config flag BEFORE changing robot mode
- # This prevents control loop from sending wrong command type during transition
- self.config.velocity_control = True
-
- # Step 1: Set mode to 4 (velocity control)
- code = self.arm.set_mode(4)
- if code != 0:
- logger.warning(f"Failed to set mode to 4: code={code}")
- self.config.velocity_control = False # Revert on failure
- return (code, f"Failed to set mode: code={code}")
-
- # Step 2: Set state to 0 (ready/sport mode) - this activates the mode!
- code = self.arm.set_state(0)
- if code == 0:
- logger.info("Velocity control mode enabled (mode=4, state=0)")
- return (code, "Velocity control mode enabled")
- else:
- logger.warning(f"Failed to set state to 0: code={code}")
- self.config.velocity_control = False # Revert on failure
- return (code, f"Failed to set state: code={code}")
- except Exception as e:
- logger.error(f"enable_velocity_control_mode failed: {e}")
- self.config.velocity_control = False # Revert on exception
- return (-1, str(e))
-
- @rpc
- def disable_velocity_control_mode(self) -> tuple[int, str]:
- """
- Disable velocity control mode and return to position control (mode 1).
-
- Returns:
- Tuple of (code, message)
- """
- try:
- # IMPORTANT: Set config flag BEFORE changing robot mode
- # This prevents control loop from sending velocity commands after mode change
- self.config.velocity_control = False
-
- # Step 1: Clear any errors that may have occurred
- self.arm.clean_error()
- self.arm.clean_warn()
-
- # Step 2: Set mode to 1 (servo/position control)
- code = self.arm.set_mode(1)
- if code != 0:
- logger.warning(f"Failed to set mode to 1: code={code}")
- self.config.velocity_control = True # Revert on failure
- return (code, f"Failed to set mode: code={code}")
-
- # Step 3: Set state to 0 (ready) - CRITICAL for accepting new commands
- code = self.arm.set_state(0)
- if code == 0:
- logger.info("Position control mode enabled (state=0, mode=1)")
- return (code, "Position control mode enabled")
- else:
- logger.warning(f"Failed to set state to 0: code={code}")
- self.config.velocity_control = True # Revert on failure
- return (code, f"Failed to set state: code={code}")
- except Exception as e:
- logger.error(f"disable_velocity_control_mode failed: {e}")
- self.config.velocity_control = True # Revert on exception
- return (-1, str(e))
-
- @rpc
- def motion_enable(self, enable: bool = True) -> tuple[int, str]:
- """Enable or disable arm motion."""
- try:
- code = self.arm.motion_enable(enable=enable)
- msg = f"Motion {'enabled' if enable else 'disabled'}"
- return (code, msg if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_state(self, state: int) -> tuple[int, str]:
- """
- Set robot state.
-
- Args:
- state: 0=ready, 3=pause, 4=stop
- """
- try:
- code = self.arm.set_state(state=state)
- return (code, "Success" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def clean_error(self) -> tuple[int, str]:
- """Clear error codes."""
- try:
- code = self.arm.clean_error()
- return (code, "Errors cleared" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def clean_warn(self) -> tuple[int, str]:
- """Clear warning codes."""
- try:
- code = self.arm.clean_warn()
- return (code, "Warnings cleared" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def emergency_stop(self) -> tuple[int, str]:
- """Emergency stop the arm."""
- try:
- code = self.arm.emergency_stop()
- return (code, "Emergency stop" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Configuration & Persistence
- # =========================================================================
-
- @rpc
- def clean_conf(self) -> tuple[int, str]:
- """Clean configuration."""
- try:
- code = self.arm.clean_conf()
- return (code, "Configuration cleaned" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def save_conf(self) -> tuple[int, str]:
- """Save current configuration to robot."""
- try:
- code = self.arm.save_conf()
- return (code, "Configuration saved" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def reload_dynamics(self) -> tuple[int, str]:
- """Reload dynamics parameters."""
- try:
- code = self.arm.reload_dynamics()
- return (code, "Dynamics reloaded" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Mode & State Control
- # =========================================================================
-
- @rpc
- def set_mode(self, mode: int) -> tuple[int, str]:
- """
- Set control mode.
-
- Args:
- mode: 0=position, 1=servo, 4=velocity, etc.
- """
- try:
- code = self.arm.set_mode(mode)
- return (code, f"Mode set to {mode}" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Collision & Safety
- # =========================================================================
-
- @rpc
- def set_collision_sensitivity(self, sensitivity: int) -> tuple[int, str]:
- """Set collision sensitivity (0-5, 0=least sensitive)."""
- try:
- code = self.arm.set_collision_sensitivity(sensitivity)
- return (
- code,
- f"Collision sensitivity set to {sensitivity}"
- if code == 0
- else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_teach_sensitivity(self, sensitivity: int) -> tuple[int, str]:
- """Set teach sensitivity (1-5)."""
- try:
- code = self.arm.set_teach_sensitivity(sensitivity)
- return (
- code,
- f"Teach sensitivity set to {sensitivity}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_collision_rebound(self, enable: int) -> tuple[int, str]:
- """Enable/disable collision rebound (0=disable, 1=enable)."""
- try:
- code = self.arm.set_collision_rebound(enable)
- return (
- code,
- f"Collision rebound {'enabled' if enable else 'disabled'}"
- if code == 0
- else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_self_collision_detection(self, enable: int) -> tuple[int, str]:
- """Enable/disable self collision detection."""
- try:
- code = self.arm.set_self_collision_detection(enable)
- return (
- code,
- f"Self collision detection {'enabled' if enable else 'disabled'}"
- if code == 0
- else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Reduced Mode & Boundaries
- # =========================================================================
-
- @rpc
- def set_reduced_mode(self, enable: int) -> tuple[int, str]:
- """Enable/disable reduced mode."""
- try:
- code = self.arm.set_reduced_mode(enable)
- return (
- code,
- f"Reduced mode {'enabled' if enable else 'disabled'}"
- if code == 0
- else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_reduced_max_tcp_speed(self, speed: float) -> tuple[int, str]:
- """Set maximum TCP speed in reduced mode."""
- try:
- code = self.arm.set_reduced_max_tcp_speed(speed)
- return (
- code,
- f"Reduced max TCP speed set to {speed}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_reduced_max_joint_speed(self, speed: float) -> tuple[int, str]:
- """Set maximum joint speed in reduced mode."""
- try:
- code = self.arm.set_reduced_max_joint_speed(speed)
- return (
- code,
- f"Reduced max joint speed set to {speed}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_fence_mode(self, enable: int) -> tuple[int, str]:
- """Enable/disable fence mode."""
- try:
- code = self.arm.set_fence_mode(enable)
- return (
- code,
- f"Fence mode {'enabled' if enable else 'disabled'}"
- if code == 0
- else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # TCP & Dynamics Configuration
- # =========================================================================
-
- @rpc
- def set_tcp_offset(self, offset: list[float]) -> tuple[int, str]:
- """Set TCP offset [x, y, z, roll, pitch, yaw]."""
- try:
- code = self.arm.set_tcp_offset(offset)
- return (code, "TCP offset set" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_tcp_load(self, weight: float, center_of_gravity: list[float]) -> tuple[int, str]:
- """Set TCP load (payload)."""
- try:
- code = self.arm.set_tcp_load(weight, center_of_gravity)
- return (code, f"TCP load set: {weight}kg" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_gravity_direction(self, direction: list[float]) -> tuple[int, str]:
- """Set gravity direction vector."""
- try:
- code = self.arm.set_gravity_direction(direction)
- return (code, "Gravity direction set" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_world_offset(self, offset: list[float]) -> tuple[int, str]:
- """Set world coordinate offset."""
- try:
- code = self.arm.set_world_offset(offset)
- return (code, "World offset set" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Motion Parameters
- # =========================================================================
-
- @rpc
- def set_tcp_jerk(self, jerk: float) -> tuple[int, str]:
- """Set TCP jerk (mm/s³)."""
- try:
- code = self.arm.set_tcp_jerk(jerk)
- return (code, f"TCP jerk set to {jerk}" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_tcp_maxacc(self, acc: float) -> tuple[int, str]:
- """Set TCP maximum acceleration (mm/s²)."""
- try:
- code = self.arm.set_tcp_maxacc(acc)
- return (
- code,
- f"TCP max acceleration set to {acc}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_joint_jerk(self, jerk: float) -> tuple[int, str]:
- """Set joint jerk (rad/s³ or °/s³)."""
- try:
- code = self.arm.set_joint_jerk(jerk, is_radian=self.config.is_radian)
- return (code, f"Joint jerk set to {jerk}" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_joint_maxacc(self, acc: float) -> tuple[int, str]:
- """Set joint maximum acceleration (rad/s² or °/s²)."""
- try:
- code = self.arm.set_joint_maxacc(acc, is_radian=self.config.is_radian)
- return (
- code,
- f"Joint max acceleration set to {acc}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
-
- @rpc
- def set_pause_time(self, seconds: float) -> tuple[int, str]:
- """Set pause time for motion commands."""
- try:
- code = self.arm.set_pause_time(seconds)
- return (code, f"Pause time set to {seconds}s" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Digital I/O (Tool GPIO)
- # =========================================================================
-
- @rpc
- def get_tgpio_digital(self, io_num: int) -> tuple[int, int | None]:
- """Get tool GPIO digital input value."""
- try:
- code, value = self.arm.get_tgpio_digital(io_num)
- return (code, value if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def set_tgpio_digital(self, io_num: int, value: int) -> tuple[int, str]:
- """Set tool GPIO digital output value (0 or 1)."""
- try:
- code = self.arm.set_tgpio_digital(io_num, value)
- return (code, f"TGPIO {io_num} set to {value}" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Digital I/O (Controller GPIO)
- # =========================================================================
-
- @rpc
- def get_cgpio_digital(self, io_num: int) -> tuple[int, int | None]:
- """Get controller GPIO digital input value."""
- try:
- code, value = self.arm.get_cgpio_digital(io_num)
- return (code, value if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def set_cgpio_digital(self, io_num: int, value: int) -> tuple[int, str]:
- """Set controller GPIO digital output value (0 or 1)."""
- try:
- code = self.arm.set_cgpio_digital(io_num, value)
- return (code, f"CGPIO {io_num} set to {value}" if code == 0 else f"Error code: {code}")
- except Exception as e:
- return (-1, str(e))
-
- # =========================================================================
- # Analog I/O
- # =========================================================================
-
- @rpc
- def get_tgpio_analog(self, io_num: int) -> tuple[int, float | None]:
- """Get tool GPIO analog input value."""
- try:
- code, value = self.arm.get_tgpio_analog(io_num)
- return (code, value if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def get_cgpio_analog(self, io_num: int) -> tuple[int, float | None]:
- """Get controller GPIO analog input value."""
- try:
- code, value = self.arm.get_cgpio_analog(io_num)
- return (code, value if code == 0 else None)
- except Exception:
- return (-1, None)
-
- @rpc
- def set_cgpio_analog(self, io_num: int, value: float) -> tuple[int, str]:
- """Set controller GPIO analog output value."""
- try:
- code = self.arm.set_cgpio_analog(io_num, value)
- return (
- code,
- f"CGPIO analog {io_num} set to {value}" if code == 0 else f"Error code: {code}",
- )
- except Exception as e:
- return (-1, str(e))
diff --git a/dimos/hardware/manipulators/xarm/spec.py b/dimos/hardware/manipulators/xarm/spec.py
deleted file mode 100644
index 625f036a0b..0000000000
--- a/dimos/hardware/manipulators/xarm/spec.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Protocol
-
-from dimos.core import In, Out
-from dimos.msgs.geometry_msgs import WrenchStamped
-from dimos.msgs.sensor_msgs import JointCommand, JointState
-
-
-@dataclass
-class RobotState:
- """Custom message containing full robot state (deprecated - use RobotStateMsg)."""
-
- state: int = 0 # Robot state (0: ready, 3: paused, 4: stopped, etc.)
- mode: int = 0 # Control mode (0: position, 1: servo, 4: joint velocity, 5: cartesian velocity)
- error_code: int = 0 # Error code
- warn_code: int = 0 # Warning code
- cmdnum: int = 0 # Command queue length
- mt_brake: int = 0 # Motor brake state
- mt_able: int = 0 # Motor enable state
-
-
-class ArmDriverSpec(Protocol):
- """Protocol specification for xArm manipulator driver.
-
- Compatible with xArm5, xArm6, and xArm7 models.
- """
-
- # Input topics (commands)
- joint_position_command: In[JointCommand] # Desired joint positions (radians)
- joint_velocity_command: In[JointCommand] # Desired joint velocities (rad/s)
-
- # Output topics
- joint_state: Out[JointState] # Current joint positions, velocities, and efforts
- robot_state: Out[RobotState] # Full robot state (errors, modes, etc.)
- ft_ext: Out[WrenchStamped] # External force/torque (compensated)
- ft_raw: Out[WrenchStamped] # Raw force/torque sensor data
-
- # RPC Methods
- def set_joint_angles(self, angles: list[float]) -> tuple[int, str]: ...
-
- def set_joint_velocities(self, velocities: list[float]) -> tuple[int, str]: ...
-
- def get_joint_state(self) -> JointState: ...
-
- def get_robot_state(self) -> RobotState: ...
-
- def enable_servo_mode(self) -> tuple[int, str]: ...
-
- def disable_servo_mode(self) -> tuple[int, str]: ...
diff --git a/dimos/hardware/manipulators/xarm/xarm_blueprints.py b/dimos/hardware/manipulators/xarm/xarm_blueprints.py
deleted file mode 100644
index 4e84c9c991..0000000000
--- a/dimos/hardware/manipulators/xarm/xarm_blueprints.py
+++ /dev/null
@@ -1,260 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Blueprints for xArm manipulator control using component-based architecture.
-
-This module provides declarative blueprints for configuring xArm with the new
-generalized component-based driver architecture.
-
-Usage:
- # Run via CLI:
- dimos run xarm-servo # Driver only
- dimos run xarm-trajectory # Driver + Joint trajectory controller
- dimos run xarm-cartesian # Driver + Cartesian motion controller
-
- # Or programmatically:
- from dimos.hardware.manipulators.xarm.xarm_blueprints import xarm_trajectory
- coordinator = xarm_trajectory.build()
- coordinator.loop()
-"""
-
-from typing import Any
-
-from dimos.core.blueprints import autoconnect
-from dimos.core.transport import LCMTransport
-from dimos.hardware.manipulators.xarm.xarm_driver import xarm_driver as xarm_driver_blueprint
-from dimos.manipulation.control import cartesian_motion_controller, joint_trajectory_controller
-from dimos.msgs.geometry_msgs import PoseStamped
-from dimos.msgs.sensor_msgs import (
- JointCommand,
- JointState,
- RobotState,
-)
-from dimos.msgs.trajectory_msgs import JointTrajectory
-
-
-# Create a blueprint wrapper for the component-based driver
-def xarm_driver(**config: Any) -> Any:
- """Create a blueprint for XArmDriver.
-
- Args:
- **config: Configuration parameters passed to XArmDriver
- - ip: IP address of XArm controller (default: "192.168.1.210")
- - dof: Degrees of freedom - 5, 6, or 7 (default: 6)
- - has_gripper: Whether gripper is attached (default: False)
- - has_force_torque: Whether F/T sensor is attached (default: False)
- - control_rate: Control loop + joint feedback rate in Hz (default: 100)
- - monitor_rate: Robot state monitoring rate in Hz (default: 10)
-
- Returns:
- Blueprint configuration for XArmDriver
- """
- # Set defaults
- config.setdefault("ip", "192.168.1.210")
- config.setdefault("dof", 6)
- config.setdefault("has_gripper", False)
- config.setdefault("has_force_torque", False)
- config.setdefault("control_rate", 100)
- config.setdefault("monitor_rate", 10)
-
- # Return the xarm_driver blueprint with the config
- return xarm_driver_blueprint(**config)
-
-
-# =============================================================================
-# xArm6 Servo Control Blueprint
-# =============================================================================
-# XArmDriver configured for servo control mode using component-based architecture.
-# Publishes joint states and robot state, listens for joint commands.
-# =============================================================================
-
-xarm_servo = xarm_driver(
- ip="192.168.1.210",
- dof=6, # XArm6
- has_gripper=False,
- has_force_torque=False,
- control_rate=100,
- monitor_rate=10,
-).transports(
- {
- # Joint state feedback (position, velocity, effort)
- ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState),
- # Robot state feedback (mode, state, errors)
- ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState),
- # Position commands input
- ("joint_position_command", JointCommand): LCMTransport(
- "/xarm/joint_position_command", JointCommand
- ),
- # Velocity commands input
- ("joint_velocity_command", JointCommand): LCMTransport(
- "/xarm/joint_velocity_command", JointCommand
- ),
- }
-)
-
-# =============================================================================
-# xArm7 Servo Control Blueprint
-# =============================================================================
-
-xarm7_servo = xarm_driver(
- ip="192.168.1.210",
- dof=7, # XArm7
- has_gripper=False,
- has_force_torque=False,
- control_rate=100,
- monitor_rate=10,
-).transports(
- {
- ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState),
- ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState),
- ("joint_position_command", JointCommand): LCMTransport(
- "/xarm/joint_position_command", JointCommand
- ),
- ("joint_velocity_command", JointCommand): LCMTransport(
- "/xarm/joint_velocity_command", JointCommand
- ),
- }
-)
-
-# =============================================================================
-# xArm5 Servo Control Blueprint
-# =============================================================================
-
-xarm5_servo = xarm_driver(
- ip="192.168.1.210",
- dof=5, # XArm5
- has_gripper=False,
- has_force_torque=False,
- control_rate=100,
- monitor_rate=10,
-).transports(
- {
- ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState),
- ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState),
- ("joint_position_command", JointCommand): LCMTransport(
- "/xarm/joint_position_command", JointCommand
- ),
- ("joint_velocity_command", JointCommand): LCMTransport(
- "/xarm/joint_velocity_command", JointCommand
- ),
- }
-)
-
-# =============================================================================
-# xArm Trajectory Control Blueprint (Driver + Trajectory Controller)
-# =============================================================================
-# Combines XArmDriver with JointTrajectoryController for trajectory execution.
-# The controller receives JointTrajectory messages and executes them at 100Hz.
-# =============================================================================
-
-xarm_trajectory = autoconnect(
- xarm_driver(
- ip="192.168.1.210",
- dof=6, # XArm6
- has_gripper=False,
- has_force_torque=False,
- control_rate=500,
- monitor_rate=10,
- ),
- joint_trajectory_controller(
- control_frequency=100.0,
- ),
-).transports(
- {
- # Shared topics between driver and controller
- ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState),
- ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState),
- ("joint_position_command", JointCommand): LCMTransport(
- "/xarm/joint_position_command", JointCommand
- ),
- # Trajectory input topic
- ("trajectory", JointTrajectory): LCMTransport("/trajectory", JointTrajectory),
- }
-)
-
-# =============================================================================
-# xArm7 Trajectory Control Blueprint
-# =============================================================================
-
-xarm7_trajectory = autoconnect(
- xarm_driver(
- ip="192.168.1.210",
- dof=7, # XArm7
- has_gripper=False,
- has_force_torque=False,
- control_rate=100,
- monitor_rate=10,
- ),
- joint_trajectory_controller(
- control_frequency=100.0,
- ),
-).transports(
- {
- ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState),
- ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState),
- ("joint_position_command", JointCommand): LCMTransport(
- "/xarm/joint_position_command", JointCommand
- ),
- ("trajectory", JointTrajectory): LCMTransport("/trajectory", JointTrajectory),
- }
-)
-
-# =============================================================================
-# xArm Cartesian Control Blueprint (Driver + Controller)
-# =============================================================================
-# Combines XArmDriver with CartesianMotionController for Cartesian space control.
-# The controller receives target_pose and converts to joint commands via IK.
-# =============================================================================
-
-xarm_cartesian = autoconnect(
- xarm_driver(
- ip="192.168.1.210",
- dof=6, # XArm6
- has_gripper=False,
- has_force_torque=False,
- control_rate=100,
- monitor_rate=10,
- ),
- cartesian_motion_controller(
- control_frequency=20.0,
- position_kp=5.0,
- position_ki=0.0,
- position_kd=0.1,
- max_linear_velocity=0.2,
- max_angular_velocity=1.0,
- ),
-).transports(
- {
- # Shared topics between driver and controller
- ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState),
- ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState),
- ("joint_position_command", JointCommand): LCMTransport(
- "/xarm/joint_position_command", JointCommand
- ),
- # Controller-specific topics
- ("target_pose", PoseStamped): LCMTransport("/target_pose", PoseStamped),
- ("current_pose", PoseStamped): LCMTransport("/xarm/current_pose", PoseStamped),
- }
-)
-
-
-__all__ = [
- "xarm5_servo",
- "xarm7_servo",
- "xarm7_trajectory",
- "xarm_cartesian",
- "xarm_servo",
- "xarm_trajectory",
-]
diff --git a/dimos/hardware/manipulators/xarm/xarm_driver.py b/dimos/hardware/manipulators/xarm/xarm_driver.py
deleted file mode 100644
index f6d950938c..0000000000
--- a/dimos/hardware/manipulators/xarm/xarm_driver.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""XArm driver using the generalized component-based architecture."""
-
-import logging
-from typing import Any
-
-from dimos.hardware.manipulators.base import (
- BaseManipulatorDriver,
- StandardMotionComponent,
- StandardServoComponent,
- StandardStatusComponent,
-)
-
-from .xarm_wrapper import XArmSDKWrapper
-
-logger = logging.getLogger(__name__)
-
-
-class XArmDriver(BaseManipulatorDriver):
- """XArm driver using component-based architecture.
-
- This driver supports XArm5, XArm6, and XArm7 models.
- All the complex logic is handled by the base class and standard components.
- This file just assembles the pieces.
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the XArm driver.
-
- Args:
- **kwargs: Arguments for Module initialization.
- Driver configuration can be passed via 'config' keyword arg:
- - ip: IP address of the XArm controller
- - dof: Degrees of freedom (5, 6, or 7)
- - has_gripper: Whether gripper is attached
- - has_force_torque: Whether F/T sensor is attached
- """
- # Extract driver-specific config from kwargs
- config: dict[str, Any] = kwargs.pop("config", {})
-
- # Extract driver-specific params that might be passed directly
- driver_params = [
- "ip",
- "dof",
- "has_gripper",
- "has_force_torque",
- "control_rate",
- "monitor_rate",
- ]
- for param in driver_params:
- if param in kwargs:
- config[param] = kwargs.pop(param)
-
- logger.info(f"Initializing XArmDriver with config: {config}")
-
- # Create SDK wrapper
- sdk = XArmSDKWrapper()
-
- # Create standard components
- components = [
- StandardMotionComponent(sdk),
- StandardServoComponent(sdk),
- StandardStatusComponent(sdk),
- ]
-
- # Optional: Add gripper component if configured
- # if config.get('has_gripper', False):
- # from dimos.hardware.manipulators.base.components import StandardGripperComponent
- # components.append(StandardGripperComponent(sdk))
-
- # Optional: Add force/torque component if configured
- # if config.get('has_force_torque', False):
- # from dimos.hardware.manipulators.base.components import StandardForceTorqueComponent
- # components.append(StandardForceTorqueComponent(sdk))
-
- # Remove any kwargs that would conflict with explicit arguments
- kwargs.pop("sdk", None)
- kwargs.pop("components", None)
- kwargs.pop("name", None)
-
- # Initialize base driver with SDK and components
- super().__init__(sdk=sdk, components=components, config=config, name="XArmDriver", **kwargs)
-
- logger.info("XArmDriver initialized successfully")
-
-
-# Blueprint configuration for the driver
-def get_blueprint() -> dict[str, Any]:
- """Get the blueprint configuration for the XArm driver.
-
- Returns:
- Dictionary with blueprint configuration
- """
- return {
- "name": "XArmDriver",
- "class": XArmDriver,
- "config": {
- "ip": "192.168.1.210", # Default IP
- "dof": 7, # Default to 7-DOF
- "has_gripper": False,
- "has_force_torque": False,
- "control_rate": 100, # Hz - control loop + joint feedback
- "monitor_rate": 10, # Hz - robot state monitoring
- },
- "inputs": {
- "joint_position_command": "JointCommand",
- "joint_velocity_command": "JointCommand",
- },
- "outputs": {
- "joint_state": "JointState",
- "robot_state": "RobotState",
- },
- "rpc_methods": [
- # Motion control
- "move_joint",
- "move_joint_velocity",
- "move_joint_effort",
- "stop_motion",
- "get_joint_state",
- "get_joint_limits",
- "get_velocity_limits",
- "set_velocity_scale",
- "set_acceleration_scale",
- "move_cartesian",
- "get_cartesian_state",
- "execute_trajectory",
- "stop_trajectory",
- # Servo control
- "enable_servo",
- "disable_servo",
- "toggle_servo",
- "get_servo_state",
- "emergency_stop",
- "reset_emergency_stop",
- "set_control_mode",
- "get_control_mode",
- "clear_errors",
- "reset_fault",
- "home_robot",
- "brake_release",
- "brake_engage",
- # Status monitoring
- "get_robot_state",
- "get_system_info",
- "get_capabilities",
- "get_error_state",
- "get_health_metrics",
- "get_statistics",
- "check_connection",
- "get_force_torque",
- "zero_force_torque",
- "get_digital_inputs",
- "set_digital_outputs",
- "get_analog_inputs",
- "get_gripper_state",
- ],
- }
-
-
-# Expose blueprint for declarative composition (compatible with dimos framework)
-xarm_driver = XArmDriver.blueprint
diff --git a/dimos/hardware/manipulators/xarm/xarm_wrapper.py b/dimos/hardware/manipulators/xarm/xarm_wrapper.py
deleted file mode 100644
index a743c0e3c7..0000000000
--- a/dimos/hardware/manipulators/xarm/xarm_wrapper.py
+++ /dev/null
@@ -1,564 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""XArm SDK wrapper implementation."""
-
-import logging
-import math
-from typing import Any
-
-from ..base.sdk_interface import BaseManipulatorSDK, ManipulatorInfo
-
-
-class XArmSDKWrapper(BaseManipulatorSDK):
- """SDK wrapper for XArm manipulators.
-
- This wrapper translates XArm's native SDK (which uses degrees and mm)
- to our standard interface (radians and meters).
- """
-
- def __init__(self) -> None:
- """Initialize the XArm SDK wrapper."""
- self.logger = logging.getLogger(self.__class__.__name__)
- self.native_sdk: Any = None
- self.dof = 7 # Default, will be updated on connect
- self._connected = False
-
- # ============= Connection Management =============
-
- def connect(self, config: dict[str, Any]) -> bool:
- """Connect to XArm controller.
-
- Args:
- config: Configuration with 'ip' and optionally 'dof' (5, 6, or 7)
-
- Returns:
- True if connection successful
- """
- try:
- from xarm import XArmAPI
-
- ip = config.get("ip", "192.168.1.100")
- self.dof = config.get("dof", 7)
-
- self.logger.info(f"Connecting to XArm at {ip} (DOF: {self.dof})...")
-
- # Create XArm API instance
- # XArm SDK uses degrees by default, we'll convert to radians
- self.native_sdk = XArmAPI(ip, is_radian=False)
-
- # Check connection
- if self.native_sdk.connected:
- # Initialize XArm
- self.native_sdk.motion_enable(True)
- self.native_sdk.set_mode(1) # Servo mode for high-frequency control
- self.native_sdk.set_state(0) # Ready state
-
- self._connected = True
- self.logger.info(
- f"Successfully connected to XArm (version: {self.native_sdk.version})"
- )
- return True
- else:
- self.logger.error("Failed to connect to XArm")
- return False
-
- except ImportError:
- self.logger.error("XArm SDK not installed. Please install: pip install xArm-Python-SDK")
- return False
- except Exception as e:
- self.logger.error(f"Connection failed: {e}")
- return False
-
- def disconnect(self) -> None:
- """Disconnect from XArm controller."""
- if self.native_sdk:
- try:
- self.native_sdk.disconnect()
- self._connected = False
- self.logger.info("Disconnected from XArm")
- except:
- pass
- finally:
- self.native_sdk = None
-
- def is_connected(self) -> bool:
- """Check if connected to XArm.
-
- Returns:
- True if connected
- """
- return self._connected and self.native_sdk and self.native_sdk.connected
-
- # ============= Joint State Query =============
-
- def get_joint_positions(self) -> list[float]:
- """Get current joint positions.
-
- Returns:
- Joint positions in RADIANS
- """
- code, angles = self.native_sdk.get_servo_angle()
- if code != 0:
- raise RuntimeError(f"XArm error getting positions: {code}")
-
- # Convert degrees to radians
- positions = [math.radians(angle) for angle in angles[: self.dof]]
- return positions
-
- def get_joint_velocities(self) -> list[float]:
- """Get current joint velocities.
-
- Returns:
- Joint velocities in RAD/S
- """
- # XArm doesn't directly provide velocities in older versions
- # Try to get from realtime data if available
- if hasattr(self.native_sdk, "get_joint_speeds"):
- code, speeds = self.native_sdk.get_joint_speeds()
- if code == 0:
- # Convert deg/s to rad/s
- return [math.radians(speed) for speed in speeds[: self.dof]]
-
- # Return zeros if not available
- return [0.0] * self.dof
-
- def get_joint_efforts(self) -> list[float]:
- """Get current joint efforts/torques.
-
- Returns:
- Joint efforts in Nm
- """
- # Try to get joint torques
- if hasattr(self.native_sdk, "get_joint_torques"):
- code, torques = self.native_sdk.get_joint_torques()
- if code == 0:
- return list(torques[: self.dof])
-
- # Return zeros if not available
- return [0.0] * self.dof
-
- # ============= Joint Motion Control =============
-
- def set_joint_positions(
- self,
- positions: list[float],
- _velocity: float = 1.0,
- _acceleration: float = 1.0,
- _wait: bool = False,
- ) -> bool:
- """Move joints to target positions using servo mode.
-
- Args:
- positions: Target positions in RADIANS
- _velocity: UNUSED in servo mode (kept for interface compatibility)
- _acceleration: UNUSED in servo mode (kept for interface compatibility)
- _wait: UNUSED in servo mode (kept for interface compatibility)
-
- Returns:
- True if command accepted
- """
- # Convert radians to degrees
- degrees = [math.degrees(pos) for pos in positions]
-
- # Use set_servo_angle_j for high-frequency servo control (100Hz+)
- # This sends immediate position commands without trajectory planning
- # Requires mode 1 (servo mode) and executes only the last instruction
- code = self.native_sdk.set_servo_angle_j(degrees, speed=100, mvacc=500, wait=False)
-
- return bool(code == 0)
-
- def set_joint_velocities(self, velocities: list[float]) -> bool:
- """Set joint velocity targets.
-
- Args:
- velocities: Target velocities in RAD/S
-
- Returns:
- True if command accepted
- """
- # Check if velocity control is supported
- if not hasattr(self.native_sdk, "vc_set_joint_velocity"):
- self.logger.warning("Velocity control not supported in this XArm version")
- return False
-
- # Convert rad/s to deg/s
- deg_velocities = [math.degrees(vel) for vel in velocities]
-
- # Set to velocity control mode if needed
- if self.native_sdk.mode != 4:
- self.native_sdk.set_mode(4) # Joint velocity mode
-
- # Send velocity command
- code = self.native_sdk.vc_set_joint_velocity(deg_velocities)
- return bool(code == 0)
-
- def set_joint_efforts(self, efforts: list[float]) -> bool:
- """Set joint effort/torque targets.
-
- Args:
- efforts: Target efforts in Nm
-
- Returns:
- True if command accepted
- """
- # Check if torque control is supported
- if not hasattr(self.native_sdk, "set_joint_torque"):
- self.logger.warning("Torque control not supported in this XArm version")
- return False
-
- # Send torque command
- code = self.native_sdk.set_joint_torque(efforts)
- return bool(code == 0)
-
- def stop_motion(self) -> bool:
- """Stop all ongoing motion.
-
- Returns:
- True if stop successful
- """
- # XArm emergency stop
- code = self.native_sdk.emergency_stop()
-
- # Re-enable after stop
- if code == 0:
- self.native_sdk.set_state(0) # Clear stop state
- self.native_sdk.motion_enable(True)
-
- return bool(code == 0)
-
- # ============= Servo Control =============
-
- def enable_servos(self) -> bool:
- """Enable motor control.
-
- Returns:
- True if servos enabled
- """
- code1 = self.native_sdk.motion_enable(True)
- code2 = self.native_sdk.set_state(0) # Ready state
- code3 = self.native_sdk.set_mode(1) # Servo mode
- return bool(code1 == 0 and code2 == 0 and code3 == 0)
-
- def disable_servos(self) -> bool:
- """Disable motor control.
-
- Returns:
- True if servos disabled
- """
- code = self.native_sdk.motion_enable(False)
- return bool(code == 0)
-
- def are_servos_enabled(self) -> bool:
- """Check if servos are enabled.
-
- Returns:
- True if enabled
- """
- # Check motor state
- return bool(self.native_sdk.mode == 1 and self.native_sdk.mode != 4)
-
- # ============= System State =============
-
- def get_robot_state(self) -> dict[str, Any]:
- """Get current robot state.
-
- Returns:
- State dictionary
- """
- return {
- "state": self.native_sdk.state, # 0=ready, 1=pause, 2=stop, 3=running, 4=error
- "mode": self.native_sdk.mode, # 0=position, 1=servo, 4=joint_vel, 5=cart_vel
- "error_code": self.native_sdk.error_code,
- "warn_code": self.native_sdk.warn_code,
- "is_moving": self.native_sdk.state == 3,
- "cmd_num": self.native_sdk.cmd_num,
- }
-
- def get_error_code(self) -> int:
- """Get current error code.
-
- Returns:
- Error code (0 = no error)
- """
- return int(self.native_sdk.error_code)
-
- def get_error_message(self) -> str:
- """Get human-readable error message.
-
- Returns:
- Error message string
- """
- if self.native_sdk.error_code == 0:
- return ""
-
- # XArm error codes (partial list)
- error_map = {
- 1: "Emergency stop button pressed",
- 2: "Joint limit exceeded",
- 3: "Command reply timeout",
- 4: "Power supply error",
- 5: "Motor overheated",
- 6: "Motor driver error",
- 7: "Other error",
- 10: "Servo error",
- 11: "Joint collision",
- 12: "Tool IO error",
- 13: "Tool communication error",
- 14: "Kinematic error",
- 15: "Self collision",
- 16: "Joint overheated",
- 17: "Planning error",
- 19: "Force control error",
- 20: "Joint current overlimit",
- 21: "TCP command overlimit",
- 22: "Overspeed",
- }
-
- return error_map.get(
- self.native_sdk.error_code, f"Unknown error {self.native_sdk.error_code}"
- )
-
- def clear_errors(self) -> bool:
- """Clear error states.
-
- Returns:
- True if errors cleared
- """
- code = self.native_sdk.clean_error()
- if code == 0:
- # Reset to ready state
- self.native_sdk.set_state(0)
- return bool(code == 0)
-
- def emergency_stop(self) -> bool:
- """Execute emergency stop.
-
- Returns:
- True if e-stop executed
- """
- code = self.native_sdk.emergency_stop()
- return bool(code == 0)
-
- # ============= Information =============
-
- def get_info(self) -> ManipulatorInfo:
- """Get manipulator information.
-
- Returns:
- ManipulatorInfo object
- """
- return ManipulatorInfo(
- vendor="UFACTORY",
- model=f"xArm{self.dof}",
- dof=self.dof,
- firmware_version=self.native_sdk.version if self.native_sdk else None,
- serial_number=self.native_sdk.get_servo_version()[1][0] if self.native_sdk else None,
- )
-
- def get_joint_limits(self) -> tuple[list[float], list[float]]:
- """Get joint position limits.
-
- Returns:
- Tuple of (lower_limits, upper_limits) in RADIANS
- """
- # XArm joint limits in degrees (approximate, varies by model)
- if self.dof == 7:
- lower_deg = [-360, -118, -360, -233, -360, -97, -360]
- upper_deg = [360, 118, 360, 11, 360, 180, 360]
- elif self.dof == 6:
- lower_deg = [-360, -118, -225, -11, -360, -97]
- upper_deg = [360, 118, 11, 225, 360, 180]
- else: # 5 DOF
- lower_deg = [-360, -118, -225, -97, -360]
- upper_deg = [360, 118, 11, 180, 360]
-
- # Convert to radians
- lower_rad = [math.radians(d) for d in lower_deg[: self.dof]]
- upper_rad = [math.radians(d) for d in upper_deg[: self.dof]]
-
- return (lower_rad, upper_rad)
-
- def get_velocity_limits(self) -> list[float]:
- """Get joint velocity limits.
-
- Returns:
- Maximum velocities in RAD/S
- """
- # XArm max velocities in deg/s (default)
- max_vel_deg = 180.0
-
- # Convert to rad/s
- max_vel_rad = math.radians(max_vel_deg)
- return [max_vel_rad] * self.dof
-
- def get_acceleration_limits(self) -> list[float]:
- """Get joint acceleration limits.
-
- Returns:
- Maximum accelerations in RAD/S²
- """
- # XArm max acceleration in deg/s² (default)
- max_acc_deg = 1145.0
-
- # Convert to rad/s²
- max_acc_rad = math.radians(max_acc_deg)
- return [max_acc_rad] * self.dof
-
- # ============= Optional Methods =============
-
- def get_cartesian_position(self) -> dict[str, float] | None:
- """Get current end-effector pose.
-
- Returns:
- Pose dict or None if not supported
- """
- code, pose = self.native_sdk.get_position()
- if code != 0:
- return None
-
- # XArm returns [x, y, z (mm), roll, pitch, yaw (degrees)]
- return {
- "x": pose[0] / 1000.0, # mm to meters
- "y": pose[1] / 1000.0,
- "z": pose[2] / 1000.0,
- "roll": math.radians(pose[3]),
- "pitch": math.radians(pose[4]),
- "yaw": math.radians(pose[5]),
- }
-
- def set_cartesian_position(
- self,
- pose: dict[str, float],
- velocity: float = 1.0,
- acceleration: float = 1.0,
- wait: bool = False,
- ) -> bool:
- """Move end-effector to target pose.
-
- Args:
- pose: Target pose dict
- velocity: Max velocity fraction (0-1)
- acceleration: Max acceleration fraction (0-1)
- wait: Block until complete
-
- Returns:
- True if command accepted
- """
- # Convert to XArm format
- xarm_pose = [
- pose["x"] * 1000.0, # meters to mm
- pose["y"] * 1000.0,
- pose["z"] * 1000.0,
- math.degrees(pose["roll"]),
- math.degrees(pose["pitch"]),
- math.degrees(pose["yaw"]),
- ]
-
- # XArm max Cartesian speed (default 500 mm/s)
- max_speed = 500.0
- speed = max_speed * velocity
-
- # XArm max Cartesian acceleration (default 2000 mm/s²)
- max_acc = 2000.0
- acc = max_acc * acceleration
-
- code = self.native_sdk.set_position(xarm_pose, radius=-1, speed=speed, mvacc=acc, wait=wait)
-
- return bool(code == 0)
-
- def get_force_torque(self) -> list[float] | None:
- """Get F/T sensor reading.
-
- Returns:
- [fx, fy, fz, tx, ty, tz] or None
- """
- if hasattr(self.native_sdk, "get_ft_sensor_data"):
- code, ft_data = self.native_sdk.get_ft_sensor_data()
- if code == 0:
- return list(ft_data)
- return None
-
- def zero_force_torque(self) -> bool:
- """Zero the F/T sensor.
-
- Returns:
- True if successful
- """
- if hasattr(self.native_sdk, "set_ft_sensor_zero"):
- code = self.native_sdk.set_ft_sensor_zero()
- return bool(code == 0)
- return False
-
- def get_gripper_position(self) -> float | None:
- """Get gripper position.
-
- Returns:
- Position in meters or None
- """
- if hasattr(self.native_sdk, "get_gripper_position"):
- code, pos = self.native_sdk.get_gripper_position()
- if code == 0:
- # Convert mm to meters
- return float(pos / 1000.0)
- return None
-
- def set_gripper_position(self, position: float, force: float = 1.0) -> bool:
- """Set gripper position.
-
- Args:
- position: Target position in meters
- force: Force fraction (0-1)
-
- Returns:
- True if successful
- """
- if hasattr(self.native_sdk, "set_gripper_position"):
- # Convert meters to mm
- pos_mm = position * 1000.0
- code = self.native_sdk.set_gripper_position(pos_mm, wait=False)
- return bool(code == 0)
- return False
-
- def set_control_mode(self, mode: str) -> bool:
- """Set control mode.
-
- Args:
- mode: 'position', 'velocity', 'torque', or 'impedance'
-
- Returns:
- True if successful
- """
- mode_map = {
- "position": 0,
- "velocity": 4, # Joint velocity mode
- "servo": 1, # Servo mode (for torque control)
- "impedance": 0, # Not directly supported, use position
- }
-
- if mode not in mode_map:
- return False
-
- code = self.native_sdk.set_mode(mode_map[mode])
- return bool(code == 0)
-
- def get_control_mode(self) -> str | None:
- """Get current control mode.
-
- Returns:
- Mode string or None
- """
- mode_map = {0: "position", 1: "servo", 4: "velocity", 5: "cartesian_velocity"}
-
- return mode_map.get(self.native_sdk.mode, "unknown")
diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py
index 10c541723a..6f51febfef 100644
--- a/dimos/hardware/sensors/camera/module.py
+++ b/dimos/hardware/sensors/camera/module.py
@@ -22,6 +22,7 @@
from dimos.agents import Output, Reducer, Stream, skill
from dimos.core import Module, ModuleConfig, Out, rpc
+from dimos.core.blueprints import autoconnect
from dimos.hardware.sensors.camera.spec import CameraHardware
from dimos.hardware.sensors.camera.webcam import Webcam
from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3
@@ -114,4 +115,9 @@ def stop(self) -> None:
camera_module = CameraModule.blueprint
+demo_camera = autoconnect(
+ camera_module(),
+)
+
+
__all__ = ["CameraModule", "camera_module"]
diff --git a/dimos/hardware/sensors/camera/realsense/README.md b/dimos/hardware/sensors/camera/realsense/README.md
new file mode 100644
index 0000000000..665833047e
--- /dev/null
+++ b/dimos/hardware/sensors/camera/realsense/README.md
@@ -0,0 +1,9 @@
+# RealSense SDK Install
+
+1) Install the Intel RealSense SDK:
+ - https://github.com/IntelRealSense/librealsense
+
+2) Install the Python bindings:
+ ```bash
+ pip install pyrealsense2
+ ```
diff --git a/dimos/manipulation/control/dual_trajectory_setter.py b/dimos/manipulation/control/dual_trajectory_setter.py
new file mode 100644
index 0000000000..4b54f0e3e5
--- /dev/null
+++ b/dimos/manipulation/control/dual_trajectory_setter.py
@@ -0,0 +1,540 @@
+#!/usr/bin/env python3
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Dual-Arm Interactive Trajectory Publisher.
+
+Interactive terminal UI for creating joint trajectories for two arms independently.
+Supports running trajectories on left arm, right arm, or both simultaneously.
+
+Workflow:
+1. Add waypoints to left or right arm (or both)
+2. Generator applies trapezoidal velocity profiles
+3. Preview the generated trajectories
+4. Run on left, right, or both arms
+
+Use with xarm-dual-trajectory blueprint running in another terminal.
+"""
+
+from dataclasses import dataclass
+import math
+import sys
+import time
+
+from dimos import core
+from dimos.manipulation.planning import JointTrajectoryGenerator
+from dimos.msgs.sensor_msgs import JointState
+from dimos.msgs.trajectory_msgs import JointTrajectory
+
+
+@dataclass
+class ArmState:
+ """State for a single arm."""
+
+ name: str
+ num_joints: int | None = None
+ latest_joint_state: JointState | None = None
+ generator: JointTrajectoryGenerator | None = None
+ waypoints: list[list[float]] | None = None
+ generated_trajectory: JointTrajectory | None = None
+
+ def __post_init__(self) -> None:
+ self.waypoints = []
+
+
+class DualTrajectorySetter:
+ """
+ Creates and publishes JointTrajectory for dual-arm setups.
+
+ Manages two arms independently with separate waypoints and trajectories.
+ Supports running trajectories on one or both arms.
+ """
+
+ def __init__(
+ self,
+ left_joint_topic: str = "/xarm/left/joint_states",
+ right_joint_topic: str = "/xarm/right/joint_states",
+ left_trajectory_topic: str = "/xarm/left/trajectory",
+ right_trajectory_topic: str = "/xarm/right/trajectory",
+ ):
+ """
+ Initialize the dual trajectory setter.
+
+ Args:
+ left_joint_topic: Topic for left arm joint states
+ right_joint_topic: Topic for right arm joint states
+ left_trajectory_topic: Topic to publish left arm trajectories
+ right_trajectory_topic: Topic to publish right arm trajectories
+ """
+ # Arm states
+ self.left = ArmState(name="left")
+ self.right = ArmState(name="right")
+
+ # Publishers for trajectories
+ self.left_trajectory_pub: core.LCMTransport[JointTrajectory] = core.LCMTransport(
+ left_trajectory_topic, JointTrajectory
+ )
+ self.right_trajectory_pub: core.LCMTransport[JointTrajectory] = core.LCMTransport(
+ right_trajectory_topic, JointTrajectory
+ )
+
+ # Subscribers for joint states
+ self.left_joint_sub: core.LCMTransport[JointState] = core.LCMTransport(
+ left_joint_topic, JointState
+ )
+ self.right_joint_sub: core.LCMTransport[JointState] = core.LCMTransport(
+ right_joint_topic, JointState
+ )
+
+ print("DualTrajectorySetter initialized")
+ print(f" Left arm: {left_joint_topic} -> {left_trajectory_topic}")
+ print(f" Right arm: {right_joint_topic} -> {right_trajectory_topic}")
+
+ def start(self) -> bool:
+ """Start subscribing to joint states."""
+ self.left_joint_sub.subscribe(self._on_left_joint_state)
+ self.right_joint_sub.subscribe(self._on_right_joint_state)
+ print(" Waiting for joint states...")
+
+ # Wait for both arms
+ left_ready = False
+ right_ready = False
+
+ for _ in range(50): # 5 second timeout
+ if not left_ready and self.left.latest_joint_state is not None:
+ self.left.num_joints = len(self.left.latest_joint_state.position)
+ self.left.generator = JointTrajectoryGenerator(
+ num_joints=self.left.num_joints,
+ max_velocity=1.0,
+ max_acceleration=2.0,
+ points_per_segment=50,
+ )
+ print(f" Left arm ready ({self.left.num_joints} joints)")
+ left_ready = True
+
+ if not right_ready and self.right.latest_joint_state is not None:
+ self.right.num_joints = len(self.right.latest_joint_state.position)
+ self.right.generator = JointTrajectoryGenerator(
+ num_joints=self.right.num_joints,
+ max_velocity=1.0,
+ max_acceleration=2.0,
+ points_per_segment=50,
+ )
+ print(f" Right arm ready ({self.right.num_joints} joints)")
+ right_ready = True
+
+ if left_ready and right_ready:
+ return True
+
+ time.sleep(0.1)
+
+ if not left_ready:
+ print(" Warning: Left arm not responding")
+ if not right_ready:
+ print(" Warning: Right arm not responding")
+
+ return left_ready or right_ready
+
+ def _on_left_joint_state(self, msg: JointState) -> None:
+ """Callback for left arm joint state."""
+ self.left.latest_joint_state = msg
+
+ def _on_right_joint_state(self, msg: JointState) -> None:
+ """Callback for right arm joint state."""
+ self.right.latest_joint_state = msg
+
+ def get_current_joints(self, arm: ArmState) -> list[float] | None:
+ """Get current joint positions for an arm."""
+ if arm.latest_joint_state is None or arm.num_joints is None:
+ return None
+ return list(arm.latest_joint_state.position[: arm.num_joints])
+
+ def generate_trajectory(self, arm: ArmState) -> JointTrajectory | None:
+ """Generate trajectory for an arm from its waypoints."""
+ if arm.generator is None or not arm.waypoints or len(arm.waypoints) < 2:
+ return None
+ return arm.generator.generate(arm.waypoints)
+
+ def publish_trajectory(self, arm: ArmState, trajectory: JointTrajectory) -> None:
+ """Publish trajectory to an arm."""
+ if arm.name == "left":
+ self.left_trajectory_pub.broadcast(None, trajectory)
+ else:
+ self.right_trajectory_pub.broadcast(None, trajectory)
+ print(
+ f" Published to {arm.name}: {len(trajectory.points)} points, "
+ f"duration={trajectory.duration:.2f}s"
+ )
+
+
+def parse_joint_input(line: str, num_joints: int) -> list[float] | None:
+ """Parse joint positions from user input (degrees by default, 'r' suffix for radians)."""
+ parts = line.strip().split()
+ if len(parts) != num_joints:
+ return None
+
+ positions = []
+ for part in parts:
+ try:
+ if part.endswith("r"):
+ positions.append(float(part[:-1]))
+ else:
+ positions.append(math.radians(float(part)))
+ except ValueError:
+ return None
+
+ return positions
+
+
+def preview_waypoints(arm: ArmState) -> None:
+ """Show waypoints for an arm."""
+ if not arm.waypoints or arm.num_joints is None:
+ print(f" {arm.name.upper()}: No waypoints")
+ return
+
+ joint_headers = " ".join([f"{'J' + str(i + 1):>7}" for i in range(arm.num_joints)])
+ line_width = 6 + 3 + arm.num_joints * 8 + 10
+
+ print(f"\n{arm.name.upper()} Waypoints ({len(arm.waypoints)}):")
+ print("-" * line_width)
+ print(f" # | {joint_headers} (degrees)")
+ print("-" * line_width)
+ for i, joints in enumerate(arm.waypoints):
+ deg = [f"{math.degrees(j):7.1f}" for j in joints]
+ print(f" {i + 1:2} | {' '.join(deg)}")
+ print("-" * line_width)
+
+
+def preview_trajectory(arm: ArmState) -> None:
+ """Show generated trajectory for an arm."""
+ if arm.generated_trajectory is None or arm.num_joints is None:
+ print(f" {arm.name.upper()}: No trajectory")
+ return
+
+ traj = arm.generated_trajectory
+ joint_headers = " ".join([f"{'J' + str(i + 1):>7}" for i in range(arm.num_joints)])
+ line_width = 9 + 3 + arm.num_joints * 8 + 10
+
+ print(f"\n{'=' * line_width}")
+ print(f"{arm.name.upper()} TRAJECTORY")
+ print(f"{'=' * line_width}")
+ print(f"Duration: {traj.duration:.3f}s | Points: {len(traj.points)}")
+ print("-" * line_width)
+ print(f"{'Time':>6} | {joint_headers} (degrees)")
+ print("-" * line_width)
+
+ num_samples = min(10, max(len(traj.points) // 10, 5))
+ for i in range(num_samples + 1):
+ t = (i / num_samples) * traj.duration
+ q_ref, _ = traj.sample(t)
+ q_deg = [f"{math.degrees(q):7.1f}" for q in q_ref]
+ print(f"{t:6.2f} | {' '.join(q_deg)}")
+
+ print("-" * line_width)
+
+
+def interactive_mode(setter: DualTrajectorySetter) -> None:
+ """Interactive mode for creating dual-arm trajectories."""
+ left = setter.left
+ right = setter.right
+
+ print("\n" + "=" * 80)
+ print("Dual-Arm Interactive Trajectory Setter")
+ print("=" * 80)
+
+ if left.num_joints:
+ print(f" Left arm: {left.num_joints} joints")
+ else:
+ print(" Left arm: NOT CONNECTED")
+
+ if right.num_joints:
+ print(f" Right arm: {right.num_joints} joints")
+ else:
+ print(" Right arm: NOT CONNECTED")
+
+ print("\nCommands:")
+ print(" left add ... - Add waypoint to left arm (degrees)")
+ print(" right add ... - Add waypoint to right arm (degrees)")
+ print(" left here - Add current position as waypoint (left)")
+ print(" right here - Add current position as waypoint (right)")
+ print(" left current - Show current left arm joints")
+ print(" right current - Show current right arm joints")
+ print(" left list - List left arm waypoints")
+ print(" right list - List right arm waypoints")
+ print(" left delete - Delete waypoint n from left")
+ print(" right delete - Delete waypoint n from right")
+ print(" left clear - Clear left arm waypoints")
+ print(" right clear - Clear right arm waypoints")
+ print(" preview - Preview both trajectories")
+ print(" run left - Run trajectory on left arm only")
+ print(" run right - Run trajectory on right arm only")
+ print(" run both - Run trajectories on both arms")
+ print(" vel - Set max velocity (rad/s)")
+ print(" quit - Exit")
+ print("=" * 80)
+
+ try:
+ while True:
+ left_wp = len(left.waypoints) if left.waypoints else 0
+ right_wp = len(right.waypoints) if right.waypoints else 0
+ prompt = f"[L:{left_wp} R:{right_wp}] > "
+ line = input(prompt).strip()
+
+ if not line:
+ continue
+
+ parts = line.split()
+ cmd = parts[0].lower()
+
+ # Determine which arm (if applicable)
+ arm: ArmState | None = None
+ if cmd in ("left", "l"):
+ arm = left
+ parts = parts[1:] # Remove arm selector
+ cmd = parts[0].lower() if parts else ""
+ elif cmd in ("right", "r"):
+ arm = right
+ parts = parts[1:]
+ cmd = parts[0].lower() if parts else ""
+
+ # ARM-SPECIFIC COMMANDS
+ if arm is not None:
+ if arm.num_joints is None:
+ print(f" {arm.name.upper()} arm not connected")
+ continue
+
+ # ADD waypoint
+ if cmd == "add" and len(parts) >= arm.num_joints + 1:
+ joints = parse_joint_input(
+ " ".join(parts[1 : arm.num_joints + 1]), arm.num_joints
+ )
+ if joints:
+ arm.waypoints.append(joints) # type: ignore[union-attr]
+ arm.generated_trajectory = None
+ deg = [f"{math.degrees(j):.1f}" for j in joints]
+ print(
+ f" {arm.name.upper()} waypoint {len(arm.waypoints)}: [{', '.join(deg)}] deg" # type: ignore[arg-type]
+ )
+ else:
+ print(f" Invalid values (need {arm.num_joints} in degrees)")
+
+ # HERE - add current position
+ elif cmd == "here":
+ joints = setter.get_current_joints(arm)
+ if joints:
+ arm.waypoints.append(joints) # type: ignore[union-attr]
+ arm.generated_trajectory = None
+ deg = [f"{math.degrees(j):.1f}" for j in joints]
+ print(
+ f" {arm.name.upper()} waypoint {len(arm.waypoints)}: [{', '.join(deg)}] deg" # type: ignore[arg-type]
+ )
+ else:
+ print(" No joint state available")
+
+ # CURRENT
+ elif cmd == "current":
+ joints = setter.get_current_joints(arm)
+ if joints:
+ deg = [f"{math.degrees(j):.1f}" for j in joints]
+ print(f" {arm.name.upper()}: [{', '.join(deg)}] deg")
+ else:
+ print(" No joint state available")
+
+ # LIST
+ elif cmd == "list":
+ preview_waypoints(arm)
+
+ # DELETE
+ elif cmd == "delete" and len(parts) >= 2:
+ try:
+ idx = int(parts[1]) - 1
+ if arm.waypoints and 0 <= idx < len(arm.waypoints):
+ arm.waypoints.pop(idx)
+ arm.generated_trajectory = None
+ print(f" Deleted {arm.name} waypoint {idx + 1}")
+ else:
+ wp_count = len(arm.waypoints) if arm.waypoints else 0
+ print(f" Invalid index (1-{wp_count})")
+ except ValueError:
+ print(" Invalid index")
+
+ # CLEAR
+ elif cmd == "clear":
+ if arm.waypoints:
+ arm.waypoints.clear()
+ arm.generated_trajectory = None
+ print(f" {arm.name.upper()} waypoints cleared")
+
+ else:
+ print(f" Unknown command for {arm.name}: {cmd}")
+
+ # GLOBAL COMMANDS
+ elif cmd == "preview":
+ # Generate trajectories if needed
+ for a in [left, right]:
+ if a.waypoints and len(a.waypoints) >= 2:
+ try:
+ a.generated_trajectory = setter.generate_trajectory(a)
+ except Exception as e:
+ print(f" Error generating {a.name} trajectory: {e}")
+ a.generated_trajectory = None
+
+ preview_trajectory(left)
+ preview_trajectory(right)
+
+ elif cmd == "run" and len(parts) >= 2:
+ target = parts[1].lower()
+
+ # Determine which arms to run
+ arms_to_run: list[ArmState] = []
+ if target in ("left", "l"):
+ arms_to_run = [left]
+ elif target in ("right", "r"):
+ arms_to_run = [right]
+ elif target == "both":
+ arms_to_run = [left, right]
+ else:
+ print(" Usage: run left|right|both")
+ continue
+
+ # Generate trajectories if needed
+ for a in arms_to_run:
+ if not a.waypoints or len(a.waypoints) < 2:
+ print(f" {a.name.upper()}: Need at least 2 waypoints")
+ continue
+
+ if a.generated_trajectory is None:
+ try:
+ a.generated_trajectory = setter.generate_trajectory(a)
+ except Exception as e:
+ print(f" Error generating {a.name} trajectory: {e}")
+ continue
+
+ # Preview and confirm
+ valid_arms = [a for a in arms_to_run if a.generated_trajectory is not None]
+ if not valid_arms:
+ print(" No valid trajectories to run")
+ continue
+
+ for a in valid_arms:
+ preview_trajectory(a)
+
+ arm_names = ", ".join(a.name.upper() for a in valid_arms)
+ confirm = input(f"\n Run on {arm_names}? [y/N]: ").strip().lower()
+ if confirm == "y":
+ print("\n Publishing trajectories...")
+ for a in valid_arms:
+ if a.generated_trajectory:
+ setter.publish_trajectory(a, a.generated_trajectory)
+
+ elif cmd == "vel" and len(parts) >= 3:
+ arm_name = parts[1].lower()
+ target_arm: ArmState | None = (
+ left
+ if arm_name in ("left", "l")
+ else right
+ if arm_name in ("right", "r")
+ else None
+ )
+ if target_arm is None or target_arm.generator is None:
+ print(" Usage: vel left|right <value>")
+ continue
+ try:
+ vel = float(parts[2])
+ if vel <= 0:
+ print(" Velocity must be positive")
+ else:
+ target_arm.generator.set_limits(vel, target_arm.generator.max_acceleration)
+ target_arm.generated_trajectory = None
+ print(f" {target_arm.name.upper()} max velocity: {vel:.2f} rad/s")
+ except ValueError:
+ print(" Invalid velocity value")
+
+ elif cmd in ("quit", "exit", "q"):
+ break
+
+ else:
+ print(f" Unknown command: {cmd}")
+
+ except KeyboardInterrupt:
+ print("\n\nExiting...")
+
+
+def main() -> int:
+ """Main entry point."""
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Dual-Arm Interactive Trajectory Setter")
+ parser.add_argument(
+ "--left-joint-topic",
+ type=str,
+ default="/xarm/left/joint_states",
+ help="Left arm joint state topic",
+ )
+ parser.add_argument(
+ "--right-joint-topic",
+ type=str,
+ default="/xarm/right/joint_states",
+ help="Right arm joint state topic",
+ )
+ parser.add_argument(
+ "--left-trajectory-topic",
+ type=str,
+ default="/xarm/left/trajectory",
+ help="Left arm trajectory topic",
+ )
+ parser.add_argument(
+ "--right-trajectory-topic",
+ type=str,
+ default="/xarm/right/trajectory",
+ help="Right arm trajectory topic",
+ )
+ args = parser.parse_args()
+
+ print("\n" + "=" * 80)
+ print("Dual-Arm Trajectory Setter")
+ print("=" * 80)
+ print("\nRun 'dimos run xarm-dual-trajectory' in another terminal first!")
+ print("=" * 80)
+
+ setter = DualTrajectorySetter(
+ left_joint_topic=args.left_joint_topic,
+ right_joint_topic=args.right_joint_topic,
+ left_trajectory_topic=args.left_trajectory_topic,
+ right_trajectory_topic=args.right_trajectory_topic,
+ )
+
+ if not setter.start():
+ print("\nWarning: Could not connect to both arms")
+ response = input("Continue anyway? [y/N]: ").strip().lower()
+ if response != "y":
+ return 0
+
+ interactive_mode(setter)
+ return 0
+
+
+if __name__ == "__main__":
+ try:
+ sys.exit(main())
+ except KeyboardInterrupt:
+ print("\n\nInterrupted by user")
+ sys.exit(0)
+ except Exception as e:
+ print(f"\nError: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/dimos/manipulation/control/orchestrator_client.py b/dimos/manipulation/control/orchestrator_client.py
new file mode 100644
index 0000000000..84e85dfb3d
--- /dev/null
+++ b/dimos/manipulation/control/orchestrator_client.py
@@ -0,0 +1,696 @@
+#!/usr/bin/env python3
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Interactive client for the ControlOrchestrator.
+
+Interfaces with a running ControlOrchestrator via RPC to:
+- Query hardware and task status
+- Plan and execute trajectories on single or multiple arms
+- Monitor execution progress
+
+Usage:
+ # Terminal 1: Start the orchestrator
+ dimos run orchestrator-mock # Single arm
+ dimos run orchestrator-dual-mock # Dual arm
+
+ # Terminal 2: Run this client
+ python -m dimos.manipulation.control.orchestrator_client
+ python -m dimos.manipulation.control.orchestrator_client --task traj_left
+ python -m dimos.manipulation.control.orchestrator_client --task traj_right
+
+How it works:
+ 1. Connects to ControlOrchestrator via LCM RPC
+ 2. Queries available hardware/tasks/joints
+ 3. You add waypoints (joint positions)
+ 4. Generates trajectory with trapezoidal velocity profile
+ 5. Sends trajectory to orchestrator via execute_trajectory() RPC
+ 6. Orchestrator's tick loop executes it at 100Hz
+"""
+
+from __future__ import annotations
+
+import math
+import sys
+import time
+from typing import TYPE_CHECKING, Any
+
+from dimos.control.orchestrator import ControlOrchestrator
+from dimos.core.rpc_client import RPCClient
+from dimos.manipulation.planning import JointTrajectoryGenerator
+
+if TYPE_CHECKING:
+ from dimos.msgs.trajectory_msgs import JointTrajectory
+
+
+class OrchestratorClient:
+ """
+ RPC client for the ControlOrchestrator.
+
+ Connects to a running orchestrator and provides methods to:
+ - Query state (joints, tasks, hardware)
+ - Execute trajectories on any task
+ - Monitor progress
+
+ Example:
+ client = OrchestratorClient()
+
+ # Query state
+ print(client.list_hardware()) # ['left_arm', 'right_arm']
+ print(client.list_tasks()) # ['traj_left', 'traj_right']
+
+ # Setup for a task
+ client.select_task("traj_left")
+
+ # Get current position and create trajectory
+ current = client.get_current_positions()
+ target = [0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+ trajectory = client.generate_trajectory([current, target])
+
+ # Execute
+ client.execute_trajectory("traj_left", trajectory)
+ """
+
+ def __init__(self) -> None:
+ """Initialize connection to orchestrator via RPC."""
+ self._rpc = RPCClient(None, ControlOrchestrator)
+
+ # Per-task state
+ self._current_task: str | None = None
+ self._task_joints: dict[str, list[str]] = {} # task_name -> joint_names
+ self._generators: dict[str, JointTrajectoryGenerator] = {} # task_name -> generator
+
+ def stop(self) -> None:
+ """Stop the RPC client."""
+ self._rpc.stop_rpc_client()
+
+ # =========================================================================
+ # Query methods (RPC calls)
+ # =========================================================================
+
+ def list_hardware(self) -> list[str]:
+ """List all hardware IDs."""
+ return self._rpc.list_hardware() or []
+
+ def list_joints(self) -> list[str]:
+ """List all joint names across all hardware."""
+ return self._rpc.list_joints() or []
+
+ def list_tasks(self) -> list[str]:
+ """List all task names."""
+ return self._rpc.list_tasks() or []
+
+ def get_active_tasks(self) -> list[str]:
+ """Get currently active task names."""
+ return self._rpc.get_active_tasks() or []
+
+ def get_joint_positions(self) -> dict[str, float]:
+ """Get current joint positions for all joints."""
+ return self._rpc.get_joint_positions() or {}
+
+ def get_trajectory_status(self, task_name: str) -> dict[str, Any]:
+ """Get status of a trajectory task."""
+ return self._rpc.get_trajectory_status(task_name) or {}
+
+ # =========================================================================
+ # Trajectory execution (RPC calls)
+ # =========================================================================
+
+ def execute_trajectory(self, task_name: str, trajectory: JointTrajectory) -> bool:
+ """Execute a trajectory on a task."""
+ return self._rpc.execute_trajectory(task_name, trajectory) or False
+
+ def cancel_trajectory(self, task_name: str) -> bool:
+ """Cancel an active trajectory."""
+ return self._rpc.cancel_trajectory(task_name) or False
+
+ # =========================================================================
+ # Task selection and setup
+ # =========================================================================
+
+ def select_task(self, task_name: str) -> bool:
+ """
+ Select a task and setup its trajectory generator.
+
+ This queries the orchestrator to find which joints the task controls,
+ then creates a trajectory generator for those joints.
+ """
+ tasks = self.list_tasks()
+ if task_name not in tasks:
+ print(f"Task '{task_name}' not found. Available: {tasks}")
+ return False
+
+ self._current_task = task_name
+
+ # Get joints for this task (infer from task name pattern)
+ # e.g., "traj_left" -> joints starting with "left_"
+ # e.g., "traj_arm" -> joints starting with "arm_"
+ all_joints = self.list_joints()
+
+ # Try to infer prefix from task name
+ if "_" in task_name:
+ prefix = task_name.split("_", 1)[1] # "traj_left" -> "left"
+ task_joints = [j for j in all_joints if j.startswith(prefix + "_")]
+ else:
+ task_joints = all_joints
+
+ if not task_joints:
+ # Fallback: use all joints
+ task_joints = all_joints
+
+ self._task_joints[task_name] = task_joints
+
+ # Create generator if not exists
+ if task_name not in self._generators:
+ self._generators[task_name] = JointTrajectoryGenerator(
+ num_joints=len(task_joints),
+ max_velocity=1.0,
+ max_acceleration=2.0,
+ points_per_segment=50,
+ )
+
+ return True
+
+ def get_task_joints(self, task_name: str | None = None) -> list[str]:
+ """Get joint names for a task."""
+ task = task_name or self._current_task
+ if task is None:
+ return []
+ return self._task_joints.get(task, [])
+
+ def get_current_positions(self, task_name: str | None = None) -> list[float] | None:
+ """Get current joint positions for a task as a list."""
+ task = task_name or self._current_task
+ if task is None:
+ return None
+
+ joints = self._task_joints.get(task, [])
+ if not joints:
+ return None
+
+ positions = self.get_joint_positions()
+ if not positions:
+ return None
+
+ return [positions.get(j, 0.0) for j in joints]
+
+ def generate_trajectory(
+ self, waypoints: list[list[float]], task_name: str | None = None
+ ) -> JointTrajectory | None:
+ """Generate trajectory from waypoints using trapezoidal velocity profile."""
+ task = task_name or self._current_task
+ if task is None:
+ print("Error: No task selected")
+ return None
+
+ generator = self._generators.get(task)
+ if generator is None:
+ print(f"Error: No generator for task '{task}'. Call select_task() first.")
+ return None
+
+ return generator.generate(waypoints)
+
+ def set_velocity_limit(self, velocity: float, task_name: str | None = None) -> None:
+ """Set max velocity for trajectory generation."""
+ task = task_name or self._current_task
+ if task and task in self._generators:
+ gen = self._generators[task]
+ gen.set_limits(velocity, gen.max_acceleration)
+
+ def set_acceleration_limit(self, acceleration: float, task_name: str | None = None) -> None:
+ """Set max acceleration for trajectory generation."""
+ task = task_name or self._current_task
+ if task and task in self._generators:
+ gen = self._generators[task]
+ gen.set_limits(gen.max_velocity, acceleration)
+
+
+# =============================================================================
+# Interactive CLI
+# =============================================================================
+
+
+def parse_joint_input(line: str, num_joints: int) -> list[float] | None:
+ """Parse joint positions from user input (degrees by default, 'r' suffix for radians)."""
+ parts = line.strip().split()
+ if len(parts) != num_joints:
+ return None
+
+ positions = []
+ for part in parts:
+ try:
+ if part.endswith("r"):
+ positions.append(float(part[:-1]))
+ else:
+ positions.append(math.radians(float(part)))
+ except ValueError:
+ return None
+
+ return positions
+
+
+def format_positions(positions: list[float], as_degrees: bool = True) -> str:
+ """Format positions for display."""
+ if as_degrees:
+ return "[" + ", ".join(f"{math.degrees(p):.1f}" for p in positions) + "] deg"
+ return "[" + ", ".join(f"{p:.3f}" for p in positions) + "] rad"
+
+
+def preview_waypoints(waypoints: list[list[float]], joint_names: list[str]) -> None:
+ """Show waypoints list."""
+ if not waypoints:
+ print("No waypoints")
+ return
+
+ print(f"\nWaypoints ({len(waypoints)}):")
+ print("-" * 70)
+
+ # Header with joint names (truncated)
+ headers = [j.split("_")[-1][:6] for j in joint_names] # e.g., "left_joint1" -> "joint1"
+ header_str = " ".join(f"{h:>7}" for h in headers)
+ print(f" # | {header_str} (degrees)")
+ print("-" * 70)
+
+ for i, joints in enumerate(waypoints):
+ deg = [f"{math.degrees(j):7.1f}" for j in joints]
+ print(f" {i + 1:2} | {' '.join(deg)}")
+ print("-" * 70)
+
+
+def preview_trajectory(trajectory: JointTrajectory, joint_names: list[str]) -> None:
+ """Show generated trajectory preview."""
+ headers = [j.split("_")[-1][:6] for j in joint_names]
+ header_str = " ".join(f"{h:>7}" for h in headers)
+
+ print("\n" + "=" * 70)
+ print("GENERATED TRAJECTORY")
+ print("=" * 70)
+ print(f"Duration: {trajectory.duration:.3f}s")
+ print(f"Points: {len(trajectory.points)}")
+ print("-" * 70)
+ print(f"{'Time':>6} | {header_str} (degrees)")
+ print("-" * 70)
+
+ num_samples = min(10, max(len(trajectory.points) // 10, 5))
+ for i in range(num_samples + 1):
+ t = (i / num_samples) * trajectory.duration
+ q_ref, _ = trajectory.sample(t)
+ q_deg = [f"{math.degrees(q):7.1f}" for q in q_ref]
+ print(f"{t:6.2f} | {' '.join(q_deg)}")
+
+ print("=" * 70)
+
+
+def wait_for_completion(client: OrchestratorClient, task_name: str, timeout: float = 60.0) -> bool:
+ """Wait for trajectory to complete with progress display."""
+ start = time.time()
+ last_progress = -1.0
+
+ while time.time() - start < timeout:
+ status = client.get_trajectory_status(task_name)
+ if not status.get("active", False):
+ state: str = status.get("state", "UNKNOWN")
+ print(f"\nTrajectory finished: {state}")
+ return state == "COMPLETED"
+
+ progress = status.get("progress", 0.0)
+ if progress != last_progress:
+ bar_len = 30
+ filled = int(bar_len * progress)
+ bar = "=" * filled + "-" * (bar_len - filled)
+ print(f"\r[{bar}] {progress * 100:.1f}%", end="", flush=True)
+ last_progress = progress
+
+ time.sleep(0.05)
+
+ print("\nTimeout waiting for trajectory")
+ return False
+
+
+class OrchestratorShell:
+ """IPython shell interface for orchestrator control."""
+
+ def __init__(self, client: OrchestratorClient, initial_task: str) -> None:
+ self._client = client
+ self._current_task = initial_task
+ self._waypoints: list[list[float]] = []
+ self._generated_trajectory: JointTrajectory | None = None
+
+ if not client.select_task(initial_task):
+ raise ValueError(f"Failed to select task: {initial_task}")
+
+ def _joints(self) -> list[str]:
+ return self._client.get_task_joints(self._current_task)
+
+ def _num_joints(self) -> int:
+ return len(self._joints())
+
+ def help(self) -> None:
+ """Show available commands."""
+ print("\nOrchestrator Client Commands:")
+ print("=" * 60)
+ print("Waypoint Commands:")
+ print(" here() - Add current position as waypoint")
+ print(" add(j1, j2, ...) - Add waypoint (degrees)")
+ print(" waypoints() - List all waypoints")
+ print(" delete(n) - Delete waypoint n")
+ print(" clear() - Clear all waypoints")
+ print("\nTrajectory Commands:")
+ print(" preview() - Preview generated trajectory")
+ print(" run() - Execute trajectory")
+ print(" status() - Show task status")
+ print(" cancel() - Cancel active trajectory")
+ print("\nMulti-Arm Commands:")
+ print(" tasks() - List all tasks")
+ print(" switch('task_name') - Switch to different task")
+ print(" hw() - List hardware")
+ print(" joints() - List joints for current task")
+ print("\nSettings:")
+ print(" current() - Show current joint positions")
+ print(" vel(value) - Set max velocity (rad/s)")
+ print(" accel(value) - Set max acceleration (rad/s^2)")
+ print("=" * 60)
+
+ def here(self) -> None:
+ """Add current position as waypoint."""
+ positions = self._client.get_current_positions(self._current_task)
+ if positions:
+ self._waypoints.append(positions)
+ self._generated_trajectory = None
+ print(f"Added waypoint {len(self._waypoints)}: {format_positions(positions)}")
+ else:
+ print("Could not get current positions")
+
+ def add(self, *joints: float) -> None:
+ """Add waypoint with specified joint values (in degrees)."""
+ num_joints = self._num_joints()
+ if len(joints) != num_joints:
+ print(f"Need {num_joints} joint values, got {len(joints)}")
+ return
+
+ rad_joints = [math.radians(j) for j in joints]
+ self._waypoints.append(rad_joints)
+ self._generated_trajectory = None
+ print(f"Added waypoint {len(self._waypoints)}: {format_positions(rad_joints)}")
+
+ def waypoints(self) -> None:
+ """List all waypoints."""
+ preview_waypoints(self._waypoints, self._joints())
+
+ def delete(self, index: int) -> None:
+ """Delete a waypoint by index (1-based)."""
+ idx = index - 1
+ if 0 <= idx < len(self._waypoints):
+ self._waypoints.pop(idx)
+ self._generated_trajectory = None
+ print(f"Deleted waypoint {index}")
+ else:
+ print(f"Invalid index (1-{len(self._waypoints)})")
+
+ def clear(self) -> None:
+ """Clear all waypoints."""
+ self._waypoints.clear()
+ self._generated_trajectory = None
+ print("Cleared waypoints")
+
+ def preview(self) -> None:
+ """Preview generated trajectory."""
+ if len(self._waypoints) < 2:
+ print("Need at least 2 waypoints")
+ return
+ try:
+ self._generated_trajectory = self._client.generate_trajectory(
+ self._waypoints, self._current_task
+ )
+ if self._generated_trajectory:
+ preview_trajectory(self._generated_trajectory, self._joints())
+ except Exception as e:
+ print(f"Error: {e}")
+
+ def run(self) -> None:
+ """Execute trajectory."""
+ if len(self._waypoints) < 2:
+ print("Need at least 2 waypoints")
+ return
+
+ if self._generated_trajectory is None:
+ self._generated_trajectory = self._client.generate_trajectory(
+ self._waypoints, self._current_task
+ )
+
+ if self._generated_trajectory is None:
+ print("Failed to generate trajectory")
+ return
+
+ preview_trajectory(self._generated_trajectory, self._joints())
+ confirm = input("\nExecute? [y/N]: ").strip().lower()
+ if confirm == "y":
+ if self._client.execute_trajectory(self._current_task, self._generated_trajectory):
+ print("Trajectory started...")
+ wait_for_completion(self._client, self._current_task)
+ else:
+ print("Failed to start trajectory")
+
+ def status(self) -> None:
+ """Show task status."""
+ status = self._client.get_trajectory_status(self._current_task)
+ print(f"\nTask: {self._current_task}")
+ print(f" Active: {status.get('active', False)}")
+ print(f" State: {status.get('state', 'UNKNOWN')}")
+ if "progress" in status:
+ print(f" Progress: {status['progress'] * 100:.1f}%")
+
+ def cancel(self) -> None:
+ """Cancel active trajectory."""
+ if self._client.cancel_trajectory(self._current_task):
+ print("Cancelled")
+ else:
+ print("Cancel failed")
+
+ def tasks(self) -> None:
+ """List all tasks."""
+ all_tasks = self._client.list_tasks()
+ active = self._client.get_active_tasks()
+ print("\nTasks:")
+ for t in all_tasks:
+ marker = "* " if t == self._current_task else " "
+ active_marker = " [ACTIVE]" if t in active else ""
+ t_joints = self._client.get_task_joints(t)
+ joint_count = len(t_joints) if t_joints else "?"
+ print(f"{marker}{t} ({joint_count} joints){active_marker}")
+
+ def switch(self, task_name: str) -> None:
+ """Switch to a different task."""
+ if self._client.select_task(task_name):
+ self._current_task = task_name
+ self._waypoints.clear()
+ self._generated_trajectory = None
+ joints = self._joints()
+ print(f"Switched to {self._current_task} ({len(joints)} joints)")
+ print(f"Joints: {', '.join(joints)}")
+ else:
+ print(f"Failed to switch to {task_name}")
+
+ def hw(self) -> None:
+ """List hardware."""
+ hardware = self._client.list_hardware()
+ print(f"\nHardware: {', '.join(hardware)}")
+
+ def joints(self) -> None:
+ """List joints for current task."""
+ joints = self._joints()
+ print(f"\nJoints for {self._current_task}:")
+ for i, j in enumerate(joints):
+ pos = self._client.get_joint_positions().get(j, 0.0)
+ print(f" {i + 1}. {j}: {math.degrees(pos):.1f} deg")
+
+ def current(self) -> None:
+ """Show current joint positions."""
+ positions = self._client.get_current_positions(self._current_task)
+ if positions:
+ print(f"Current: {format_positions(positions)}")
+ else:
+ print("Could not get positions")
+
+ def vel(self, value: float | None = None) -> None:
+ """Set or show max velocity (rad/s)."""
+ if value is None:
+ gen = self._client._generators.get(self._current_task)
+ if gen:
+ print(f"Max velocity: {gen.max_velocity[0]:.2f} rad/s")
+ return
+
+ if value <= 0:
+ print("Velocity must be positive")
+ return
+
+ self._client.set_velocity_limit(value, self._current_task)
+ self._generated_trajectory = None
+ print(f"Max velocity: {value:.2f} rad/s")
+
+ def accel(self, value: float | None = None) -> None:
+ """Set or show max acceleration (rad/s^2)."""
+ if value is None:
+ gen = self._client._generators.get(self._current_task)
+ if gen:
+ print(f"Max acceleration: {gen.max_acceleration[0]:.2f} rad/s^2")
+ return
+
+ if value <= 0:
+ print("Acceleration must be positive")
+ return
+
+ self._client.set_acceleration_limit(value, self._current_task)
+ self._generated_trajectory = None
+ print(f"Max acceleration: {value:.2f} rad/s^2")
+
+
+def interactive_mode(client: OrchestratorClient, initial_task: str) -> None:
+ """Start IPython interactive mode."""
+ import IPython
+
+ shell = OrchestratorShell(client, initial_task)
+
+ print("\n" + "=" * 60)
+ print(f"Orchestrator Client (IPython) - Task: {initial_task}")
+ print("=" * 60)
+ print(f"Joints: {', '.join(shell._joints())}")
+ print("\nType help() for available commands")
+ print("=" * 60 + "\n")
+
+ IPython.start_ipython( # type: ignore[no-untyped-call]
+ argv=[],
+ user_ns={
+ "help": shell.help,
+ "here": shell.here,
+ "add": shell.add,
+ "waypoints": shell.waypoints,
+ "delete": shell.delete,
+ "clear": shell.clear,
+ "preview": shell.preview,
+ "run": shell.run,
+ "status": shell.status,
+ "cancel": shell.cancel,
+ "tasks": shell.tasks,
+ "switch": shell.switch,
+ "hw": shell.hw,
+ "joints": shell.joints,
+ "current": shell.current,
+ "vel": shell.vel,
+ "accel": shell.accel,
+ "client": client,
+ "shell": shell,
+ },
+ )
+
+
+def _run_client(client: OrchestratorClient, task: str, vel: float, accel: float) -> int:
+ """Run the client with the given configuration."""
+ try:
+ hardware = client.list_hardware()
+ tasks = client.list_tasks()
+
+ if not hardware:
+ print("\nWarning: No hardware found. Is the orchestrator running?")
+ print("Start with: dimos run orchestrator-mock")
+ response = input("Continue anyway? [y/N]: ").strip().lower()
+ if response != "y":
+ return 0
+ else:
+ print(f"Hardware: {', '.join(hardware)}")
+ print(f"Tasks: {', '.join(tasks)}")
+
+ except Exception as e:
+ print(f"\nConnection error: {e}")
+ print("Make sure orchestrator is running: dimos run orchestrator-mock")
+ return 1
+
+ if task not in tasks and tasks:
+ print(f"\nTask '{task}' not found.")
+ print(f"Available: {', '.join(tasks)}")
+ task = tasks[0]
+ print(f"Using '{task}'")
+
+ if client.select_task(task):
+ client.set_velocity_limit(vel, task)
+ client.set_acceleration_limit(accel, task)
+
+ interactive_mode(client, task)
+ return 0
+
+
+def main() -> int:
+ """Main entry point."""
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description="Interactive client for ControlOrchestrator",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Single arm (with orchestrator-mock running)
+ python -m dimos.manipulation.control.orchestrator_client
+
+ # Dual arm - control left arm
+ python -m dimos.manipulation.control.orchestrator_client --task traj_left
+
+ # Dual arm - control right arm
+ python -m dimos.manipulation.control.orchestrator_client --task traj_right
+ """,
+ )
+ parser.add_argument(
+ "--task",
+ type=str,
+ default="traj_arm",
+ help="Initial task to control (default: traj_arm)",
+ )
+ parser.add_argument(
+ "--vel",
+ type=float,
+ default=1.0,
+ help="Max velocity in rad/s (default: 1.0)",
+ )
+ parser.add_argument(
+ "--accel",
+ type=float,
+ default=2.0,
+ help="Max acceleration in rad/s^2 (default: 2.0)",
+ )
+ args = parser.parse_args()
+
+ print("\n" + "=" * 70)
+ print("Orchestrator Client")
+ print("=" * 70)
+ print("\nConnecting to ControlOrchestrator via RPC...")
+
+ client = OrchestratorClient()
+ try:
+ return _run_client(client, args.task, args.vel, args.accel)
+ finally:
+ client.stop()
+
+
+if __name__ == "__main__":
+ try:
+ sys.exit(main())
+ except KeyboardInterrupt:
+ print("\n\nInterrupted")
+ sys.exit(0)
+ except Exception as e:
+ print(f"\nError: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py
index cfbdb77cbf..f5a0810803 100644
--- a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py
+++ b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py
@@ -34,7 +34,6 @@
from dimos.core import In, Module, Out, rpc
from dimos.core.module import ModuleConfig
-from dimos.hardware.manipulators.xarm.spec import ArmDriverSpec
from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Twist, Vector3
from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState
from dimos.utils.logging_config import setup_logger
@@ -90,7 +89,7 @@ class CartesianMotionController(Module):
5. Publishing joint commands to the driver
The controller is hardware-agnostic: it works with any arm driver that
- implements the ArmDriverSpec protocol (provides IK/FK RPC methods).
+ provides IK/FK RPC methods and JointState/RobotState outputs.
"""
default_config = CartesianMotionControllerConfig
@@ -112,12 +111,12 @@ class CartesianMotionController(Module):
cartesian_velocity: Out[Twist] = None # type: ignore[assignment]
current_pose: Out[PoseStamped] = None # type: ignore[assignment]
- def __init__(self, arm_driver: ArmDriverSpec | None = None, *args: Any, **kwargs: Any) -> None:
+ def __init__(self, arm_driver: Any = None, *args: Any, **kwargs: Any) -> None:
"""
Initialize the Cartesian motion controller.
Args:
- arm_driver: (Optional) Hardware driver implementing ArmDriverSpec protocol.
+ arm_driver: (Optional) Hardware driver reference (legacy mode).
When using blueprints, this is resolved automatically via rpc_calls.
"""
super().__init__(*args, **kwargs)
diff --git a/dimos/manipulation/control/servo_control/example_cartesian_control.py b/dimos/manipulation/control/servo_control/example_cartesian_control.py
deleted file mode 100644
index eeff04e424..0000000000
--- a/dimos/manipulation/control/servo_control/example_cartesian_control.py
+++ /dev/null
@@ -1,194 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Example: Topic-Based Cartesian Motion Control with xArm
-
-Demonstrates topic-based Cartesian space motion control. The controller
-subscribes to /target_pose and automatically moves to received targets.
-
-This example shows:
-1. Deploy xArm driver with LCM transports
-2. Deploy CartesianMotionController with LCM transports
-3. Configure controller to subscribe to /target_pose topic
-4. Keep system running to process incoming targets
-
-Use target_setter.py to publish target poses to /target_pose topic.
-
-Pattern matches: interactive_control.py + sample_trajectory_generator.py
-"""
-
-import signal
-import time
-
-from dimos import core
-from dimos.hardware.manipulators.xarm import XArmDriver
-from dimos.manipulation.control import CartesianMotionController
-from dimos.msgs.geometry_msgs import PoseStamped
-from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState
-
-# Global flag for graceful shutdown
-shutdown_requested = False
-
-
-def signal_handler(sig, frame): # type: ignore[no-untyped-def]
- """Handle Ctrl+C for graceful shutdown."""
- global shutdown_requested
- print("\n\nShutdown requested...")
- shutdown_requested = True
-
-
-def main(): # type: ignore[no-untyped-def]
- """
- Deploy and run topic-based Cartesian motion control system.
-
- The system subscribes to /target_pose and automatically moves
- the robot to received target poses.
- """
-
- # Register signal handler for graceful shutdown
- signal.signal(signal.SIGINT, signal_handler)
- signal.signal(signal.SIGTERM, signal_handler)
-
- # =========================================================================
- # Step 1: Start dimos cluster
- # =========================================================================
- print("=" * 80)
- print("Topic-Based Cartesian Motion Control")
- print("=" * 80)
- print("\nStarting dimos cluster...")
- dimos = core.start(1) # Start with 1 worker
-
- try:
- # =========================================================================
- # Step 2: Deploy xArm driver
- # =========================================================================
- print("\nDeploying xArm driver...")
- arm_driver = dimos.deploy( # type: ignore[attr-defined]
- XArmDriver,
- ip_address="192.168.1.210",
- xarm_type="xarm6",
- report_type="dev",
- enable_on_start=True,
- )
-
- # Set up driver transports
- arm_driver.joint_state.transport = core.LCMTransport("/xarm/joint_states", JointState)
- arm_driver.robot_state.transport = core.LCMTransport("/xarm/robot_state", RobotState)
- arm_driver.joint_position_command.transport = core.LCMTransport(
- "/xarm/joint_position_command", JointCommand
- )
- arm_driver.joint_velocity_command.transport = core.LCMTransport(
- "/xarm/joint_velocity_command", JointCommand
- )
-
- print("Starting xArm driver...")
- arm_driver.start()
-
- # =========================================================================
- # Step 3: Deploy Cartesian motion controller
- # =========================================================================
- print("\nDeploying Cartesian motion controller...")
- controller = dimos.deploy( # type: ignore[attr-defined]
- CartesianMotionController,
- arm_driver=arm_driver,
- control_frequency=20.0,
- position_kp=1.0,
- position_kd=0.1,
- orientation_kp=2.0,
- orientation_kd=0.2,
- max_linear_velocity=0.15,
- max_angular_velocity=0.8,
- position_tolerance=0.002,
- orientation_tolerance=0.02,
- velocity_control_mode=True,
- )
-
- # Set up controller transports
- controller.joint_state.transport = core.LCMTransport("/xarm/joint_states", JointState)
- controller.robot_state.transport = core.LCMTransport("/xarm/robot_state", RobotState)
- controller.joint_position_command.transport = core.LCMTransport(
- "/xarm/joint_position_command", JointCommand
- )
-
- # IMPORTANT: Configure controller to subscribe to /target_pose topic
- controller.target_pose.transport = core.LCMTransport("/target_pose", PoseStamped)
-
- # Publish current pose for target setters to use
- controller.current_pose.transport = core.LCMTransport("/xarm/current_pose", PoseStamped)
-
- print("Starting controller...")
- controller.start()
-
- # =========================================================================
- # Step 4: Keep system running
- # =========================================================================
- print("\n" + "=" * 80)
- print("✓ System ready!")
- print("=" * 80)
- print("\nController is now listening to /target_pose topic")
- print("Use target_setter.py to publish target poses")
- print("\nPress Ctrl+C to shutdown")
- print("=" * 80 + "\n")
-
- # Keep running until shutdown requested
- while not shutdown_requested:
- time.sleep(0.5)
-
- # =========================================================================
- # Step 5: Clean shutdown
- # =========================================================================
- print("\nShutting down...")
- print("Stopping controller...")
- controller.stop()
- print("Stopping driver...")
- arm_driver.stop()
- print("✓ Shutdown complete")
-
- finally:
- # Always stop dimos cluster
- print("Stopping dimos cluster...")
- dimos.stop() # type: ignore[attr-defined]
-
-
-if __name__ == "__main__":
- """
- Topic-Based Cartesian Control for xArm.
-
- Usage:
- # Terminal 1: Start the controller (this script)
- python3 example_cartesian_control.py
-
- # Terminal 2: Publish target poses
- python3 target_setter.py --world 0.4 0.0 0.5 # Absolute world coordinates
- python3 target_setter.py --relative 0.05 0 0 # Relative movement (50mm in X)
-
- The controller subscribes to /target_pose topic and automatically moves
- the robot to received target poses.
-
- Requirements:
- - xArm robot connected at 192.168.2.235
- - Robot will be automatically enabled in servo mode
- - Proper network configuration
- """
- try:
- main() # type: ignore[no-untyped-call]
- except KeyboardInterrupt:
- print("\n\nInterrupted by user")
- except Exception as e:
- print(f"\nError: {e}")
- import traceback
-
- traceback.print_exc()
diff --git a/dimos/manipulation/control/trajectory_controller/example_trajectory_control.py b/dimos/manipulation/control/trajectory_controller/example_trajectory_control.py
deleted file mode 100644
index 100e095a45..0000000000
--- a/dimos/manipulation/control/trajectory_controller/example_trajectory_control.py
+++ /dev/null
@@ -1,189 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Example: Joint Trajectory Control with xArm
-
-Demonstrates joint-space trajectory execution. The controller
-executes trajectories by sampling at 100Hz and sending joint commands.
-
-This example shows:
-1. Deploy xArm driver with LCM transports
-2. Deploy JointTrajectoryController with LCM transports
-3. Execute trajectories via RPC or topic
-4. Monitor execution status
-
-Use trajectory_setter.py to interactively create and execute trajectories.
-"""
-
-import signal
-import time
-
-from dimos import core
-from dimos.hardware.manipulators.xarm import XArmDriver
-from dimos.manipulation.control import JointTrajectoryController
-from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState
-from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryState
-
-# Global flag for graceful shutdown
-shutdown_requested = False
-
-
-def signal_handler(sig, frame): # type: ignore[no-untyped-def]
- """Handle Ctrl+C for graceful shutdown."""
- global shutdown_requested
- print("\n\nShutdown requested...")
- shutdown_requested = True
-
-
-def main(): # type: ignore[no-untyped-def]
- """
- Deploy and run joint trajectory control system.
-
- The system executes joint trajectories at 100Hz by sampling
- and forwarding joint positions to the arm driver.
- """
-
- # Register signal handler for graceful shutdown
- signal.signal(signal.SIGINT, signal_handler)
- signal.signal(signal.SIGTERM, signal_handler)
-
- # =========================================================================
- # Step 1: Start dimos cluster
- # =========================================================================
- print("=" * 80)
- print("Joint Trajectory Control")
- print("=" * 80)
- print("\nStarting dimos cluster...")
- dimos = core.start(1) # Start with 1 worker
-
- try:
- # =========================================================================
- # Step 2: Deploy xArm driver
- # =========================================================================
- print("\nDeploying xArm driver...")
- arm_driver = dimos.deploy( # type: ignore[attr-defined]
- XArmDriver,
- ip_address="192.168.1.210",
- xarm_type="xarm6",
- report_type="dev",
- enable_on_start=True,
- )
-
- # Set up driver transports
- arm_driver.joint_state.transport = core.LCMTransport("/xarm/joint_states", JointState)
- arm_driver.robot_state.transport = core.LCMTransport("/xarm/robot_state", RobotState)
- arm_driver.joint_position_command.transport = core.LCMTransport(
- "/xarm/joint_position_command", JointCommand
- )
-
- print("Starting xArm driver...")
- arm_driver.start()
-
- # =========================================================================
- # Step 3: Deploy Joint Trajectory Controller
- # =========================================================================
- print("\nDeploying Joint Trajectory Controller...")
- controller = dimos.deploy( # type: ignore[attr-defined]
- JointTrajectoryController,
- control_frequency=100.0, # 100Hz execution
- )
-
- # Set up controller transports
- controller.joint_state.transport = core.LCMTransport("/xarm/joint_states", JointState)
- controller.robot_state.transport = core.LCMTransport("/xarm/robot_state", RobotState)
- controller.joint_position_command.transport = core.LCMTransport(
- "/xarm/joint_position_command", JointCommand
- )
-
- # Subscribe to trajectory topic (from trajectory_setter.py)
- controller.trajectory.transport = core.LCMTransport("/trajectory", JointTrajectory)
-
- print("Starting controller...")
- controller.start()
-
- # Wait for joint state
- print("\nWaiting for joint state...")
- time.sleep(1.0)
-
- # =========================================================================
- # Step 4: Keep system running
- # =========================================================================
- print("\n" + "=" * 80)
- print("System ready!")
- print("=" * 80)
- print("\nJoint Trajectory Controller is running at 100Hz")
- print("Listening on /trajectory topic")
- print("\nUse trajectory_setter.py in another terminal to publish trajectories")
- print("\nPress Ctrl+C to shutdown")
- print("=" * 80 + "\n")
-
- # Keep running until shutdown requested
- while not shutdown_requested:
- # Print status periodically
- status = controller.get_status()
- if status.state == TrajectoryState.EXECUTING:
- print(
- f"\rExecuting: {status.progress:.1%} | "
- f"elapsed={status.time_elapsed:.2f}s | "
- f"remaining={status.time_remaining:.2f}s",
- end="",
- )
- time.sleep(0.5)
-
- # =========================================================================
- # Step 5: Clean shutdown
- # =========================================================================
- print("\n\nShutting down...")
- print("Stopping controller...")
- controller.stop()
- print("Stopping driver...")
- arm_driver.stop()
- print("Shutdown complete")
-
- finally:
- # Always stop dimos cluster
- print("Stopping dimos cluster...")
- dimos.stop() # type: ignore[attr-defined]
-
-
-if __name__ == "__main__":
- """
- Joint Trajectory Control for xArm.
-
- Usage:
- # Terminal 1: Start the controller (this script)
- python3 example_trajectory_control.py
-
- # Terminal 2: Create and execute trajectories
- python3 trajectory_setter.py
-
- The controller executes joint trajectories at 100Hz by sampling
- and forwarding joint positions to the arm driver.
-
- Requirements:
- - xArm robot connected at 192.168.1.210
- - Robot will be automatically enabled in servo mode
- - Proper network configuration
- """
- try:
- main() # type: ignore[no-untyped-call]
- except KeyboardInterrupt:
- print("\n\nInterrupted by user")
- except Exception as e:
- print(f"\nError: {e}")
- import traceback
-
- traceback.print_exc()
diff --git a/dimos/manipulation/manip_aio_pipeline.py b/dimos/manipulation/manip_aio_pipeline.py
deleted file mode 100644
index fe3598ab1e..0000000000
--- a/dimos/manipulation/manip_aio_pipeline.py
+++ /dev/null
@@ -1,592 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Asynchronous, reactive manipulation pipeline for realtime detection, filtering, and grasp generation.
-"""
-
-import asyncio
-import json
-import threading
-import time
-
-import cv2
-import numpy as np
-import reactivex as rx
-import reactivex.operators as ops
-import websockets
-
-from dimos.perception.common.utils import colorize_depth
-from dimos.perception.detection2d.detic_2d_det import ( # type: ignore[import-not-found, import-untyped]
- Detic2DDetector,
-)
-from dimos.perception.grasp_generation.utils import draw_grasps_on_image
-from dimos.perception.object_detection_stream import ObjectDetectionStream
-from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering
-from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class ManipulationPipeline:
- """
- Clean separated stream pipeline with frame buffering.
-
- - Object detection runs independently on RGB stream
- - Point cloud processing subscribes to both detection and ZED streams separately
- - Simple frame buffering to match RGB+depth+objects
- """
-
- def __init__(
- self,
- camera_intrinsics: list[float], # [fx, fy, cx, cy]
- min_confidence: float = 0.6,
- max_objects: int = 10,
- vocabulary: str | None = None,
- grasp_server_url: str | None = None,
- enable_grasp_generation: bool = False,
- ) -> None:
- """
- Initialize the manipulation pipeline.
-
- Args:
- camera_intrinsics: [fx, fy, cx, cy] camera parameters
- min_confidence: Minimum detection confidence threshold
- max_objects: Maximum number of objects to process
- vocabulary: Optional vocabulary for Detic detector
- grasp_server_url: Optional WebSocket URL for Dimensional Grasp server
- enable_grasp_generation: Whether to enable async grasp generation
- """
- self.camera_intrinsics = camera_intrinsics
- self.min_confidence = min_confidence
-
- # Grasp generation settings
- self.grasp_server_url = grasp_server_url
- self.enable_grasp_generation = enable_grasp_generation
-
- # Asyncio event loop for WebSocket communication
- self.grasp_loop = None
- self.grasp_loop_thread = None
-
- # Storage for grasp results and filtered objects
- self.latest_grasps: list[dict] = [] # type: ignore[type-arg] # Simplified: just a list of grasps
- self.grasps_consumed = False
- self.latest_filtered_objects = [] # type: ignore[var-annotated]
- self.latest_rgb_for_grasps = None # Store RGB image for grasp overlay
- self.grasp_lock = threading.Lock()
-
- # Track pending requests - simplified to single task
- self.grasp_task: asyncio.Task | None = None # type: ignore[type-arg]
-
- # Reactive subjects for streaming filtered objects and grasps
- self.filtered_objects_subject = rx.subject.Subject() # type: ignore[var-annotated]
- self.grasps_subject = rx.subject.Subject() # type: ignore[var-annotated]
- self.grasp_overlay_subject = rx.subject.Subject() # type: ignore[var-annotated] # Add grasp overlay subject
-
- # Initialize grasp client if enabled
- if self.enable_grasp_generation and self.grasp_server_url:
- self._start_grasp_loop()
-
- # Initialize object detector
- self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence)
-
- # Initialize point cloud processor
- self.pointcloud_filter = PointcloudFiltering(
- color_intrinsics=camera_intrinsics,
- depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics
- max_num_objects=max_objects,
- )
-
- logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}")
-
- def create_streams(self, zed_stream: rx.Observable) -> dict[str, rx.Observable]: # type: ignore[type-arg]
- """
- Create streams using exact old main logic.
- """
- # Create ZED streams (from old main)
- zed_frame_stream = zed_stream.pipe(ops.share())
-
- # RGB stream for object detection (from old main)
- video_stream = zed_frame_stream.pipe(
- ops.map(lambda x: x.get("rgb") if x is not None else None), # type: ignore[attr-defined]
- ops.filter(lambda x: x is not None),
- ops.share(),
- )
- object_detector = ObjectDetectionStream(
- camera_intrinsics=self.camera_intrinsics,
- min_confidence=self.min_confidence,
- class_filter=None,
- detector=self.detector,
- video_stream=video_stream,
- disable_depth=True,
- )
-
- # Store latest frames for point cloud processing (from old main)
- latest_rgb = None
- latest_depth = None
- latest_point_cloud_overlay = None
- frame_lock = threading.Lock()
-
- # Subscribe to combined ZED frames (from old main)
- def on_zed_frame(zed_data) -> None: # type: ignore[no-untyped-def]
- nonlocal latest_rgb, latest_depth
- if zed_data is not None:
- with frame_lock:
- latest_rgb = zed_data.get("rgb")
- latest_depth = zed_data.get("depth")
-
- # Depth stream for point cloud filtering (from old main)
- def get_depth_or_overlay(zed_data): # type: ignore[no-untyped-def]
- if zed_data is None:
- return None
-
- # Check if we have a point cloud overlay available
- with frame_lock:
- overlay = latest_point_cloud_overlay
-
- if overlay is not None:
- return overlay
- else:
- # Return regular colorized depth
- return colorize_depth(zed_data.get("depth"), max_depth=10.0)
-
- depth_stream = zed_frame_stream.pipe(
- ops.map(get_depth_or_overlay), ops.filter(lambda x: x is not None), ops.share()
- )
-
- # Process object detection results with point cloud filtering (from old main)
- def on_detection_next(result) -> None: # type: ignore[no-untyped-def]
- nonlocal latest_point_cloud_overlay
- if result.get("objects"):
- # Get latest RGB and depth frames
- with frame_lock:
- rgb = latest_rgb
- depth = latest_depth
-
- if rgb is not None and depth is not None:
- try:
- filtered_objects = self.pointcloud_filter.process_images(
- rgb, depth, result["objects"]
- )
-
- if filtered_objects:
- # Store filtered objects
- with self.grasp_lock:
- self.latest_filtered_objects = filtered_objects
- self.filtered_objects_subject.on_next(filtered_objects)
-
- # Create base image (colorized depth)
- base_image = colorize_depth(depth, max_depth=10.0)
-
- # Create point cloud overlay visualization
- overlay_viz = create_point_cloud_overlay_visualization(
- base_image=base_image, # type: ignore[arg-type]
- objects=filtered_objects, # type: ignore[arg-type]
- intrinsics=self.camera_intrinsics, # type: ignore[arg-type]
- )
-
- # Store the overlay for the stream
- with frame_lock:
- latest_point_cloud_overlay = overlay_viz
-
- # Request grasps if enabled
- if self.enable_grasp_generation and len(filtered_objects) > 0:
- # Save RGB image for later grasp overlay
- with frame_lock:
- self.latest_rgb_for_grasps = rgb.copy()
-
- task = self.request_scene_grasps(filtered_objects) # type: ignore[arg-type]
- if task:
- # Check for results after a delay
- def check_grasps_later() -> None:
- time.sleep(2.0) # Wait for grasp processing
- # Wait for task to complete
- if hasattr(self, "grasp_task") and self.grasp_task:
- try:
- self.grasp_task.result( # type: ignore[call-arg]
- timeout=3.0
- ) # Get result with timeout
- except Exception as e:
- logger.warning(f"Grasp task failed or timeout: {e}")
-
- # Try to get latest grasps and create overlay
- with self.grasp_lock:
- grasps = self.latest_grasps
-
- if grasps and hasattr(self, "latest_rgb_for_grasps"):
- # Create grasp overlay on the saved RGB image
- try:
- bgr_image = cv2.cvtColor( # type: ignore[call-overload]
- self.latest_rgb_for_grasps, cv2.COLOR_RGB2BGR
- )
- result_bgr = draw_grasps_on_image(
- bgr_image,
- grasps,
- self.camera_intrinsics,
- max_grasps=-1, # Show all grasps
- )
- result_rgb = cv2.cvtColor(
- result_bgr, cv2.COLOR_BGR2RGB
- )
-
- # Emit grasp overlay immediately
- self.grasp_overlay_subject.on_next(result_rgb)
-
- except Exception as e:
- logger.error(f"Error creating grasp overlay: {e}")
-
- # Emit grasps to stream
- self.grasps_subject.on_next(grasps)
-
- threading.Thread(target=check_grasps_later, daemon=True).start()
- else:
- logger.warning("Failed to create grasp task")
- except Exception as e:
- logger.error(f"Error in point cloud filtering: {e}")
- with frame_lock:
- latest_point_cloud_overlay = None
-
- def on_error(error) -> None: # type: ignore[no-untyped-def]
- logger.error(f"Error in stream: {error}")
-
- def on_completed() -> None:
- logger.info("Stream completed")
-
- def start_subscriptions() -> None:
- """Start subscriptions in background thread (from old main)"""
- # Subscribe to combined ZED frames
- zed_frame_stream.subscribe(on_next=on_zed_frame)
-
- # Start subscriptions in background thread (from old main)
- subscription_thread = threading.Thread(target=start_subscriptions, daemon=True)
- subscription_thread.start()
- time.sleep(2) # Give subscriptions time to start
-
- # Subscribe to object detection stream (from old main)
- object_detector.get_stream().subscribe( # type: ignore[no-untyped-call]
- on_next=on_detection_next, on_error=on_error, on_completed=on_completed
- )
-
- # Create visualization stream for web interface (from old main)
- viz_stream = object_detector.get_stream().pipe( # type: ignore[no-untyped-call]
- ops.map(lambda x: x["viz_frame"] if x is not None else None), # type: ignore[index]
- ops.filter(lambda x: x is not None),
- )
-
- # Create filtered objects stream
- filtered_objects_stream = self.filtered_objects_subject
-
- # Create grasps stream
- grasps_stream = self.grasps_subject
-
- # Create grasp overlay subject for immediate emission
- grasp_overlay_stream = self.grasp_overlay_subject
-
- return {
- "detection_viz": viz_stream,
- "pointcloud_viz": depth_stream,
- "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), # type: ignore[attr-defined, no-untyped-call]
- "filtered_objects": filtered_objects_stream,
- "grasps": grasps_stream,
- "grasp_overlay": grasp_overlay_stream,
- }
-
- def _start_grasp_loop(self) -> None:
- """Start asyncio event loop in a background thread for WebSocket communication."""
-
- def run_loop() -> None:
- self.grasp_loop = asyncio.new_event_loop() # type: ignore[assignment]
- asyncio.set_event_loop(self.grasp_loop)
- self.grasp_loop.run_forever() # type: ignore[attr-defined]
-
- self.grasp_loop_thread = threading.Thread(target=run_loop, daemon=True) # type: ignore[assignment]
- self.grasp_loop_thread.start() # type: ignore[attr-defined]
-
- # Wait for loop to start
- while self.grasp_loop is None:
- time.sleep(0.01)
-
- async def _send_grasp_request(
- self,
- points: np.ndarray, # type: ignore[type-arg]
- colors: np.ndarray | None, # type: ignore[type-arg]
- ) -> list[dict] | None: # type: ignore[type-arg]
- """Send grasp request to Dimensional Grasp server."""
- try:
- # Comprehensive client-side validation to prevent server errors
-
- # Validate points array
- if points is None:
- logger.error("Points array is None")
- return None
- if not isinstance(points, np.ndarray):
- logger.error(f"Points is not numpy array: {type(points)}")
- return None
- if points.size == 0:
- logger.error("Points array is empty")
- return None
- if len(points.shape) != 2 or points.shape[1] != 3:
- logger.error(f"Points has invalid shape {points.shape}, expected (N, 3)")
- return None
- if points.shape[0] < 100: # Minimum points for stable grasp detection
- logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100")
- return None
-
- # Validate and prepare colors
- if colors is not None:
- if not isinstance(colors, np.ndarray):
- colors = None
- elif colors.size == 0:
- colors = None
- elif len(colors.shape) != 2 or colors.shape[1] != 3:
- colors = None
- elif colors.shape[0] != points.shape[0]:
- colors = None
-
- # If no valid colors, create default colors (required by server)
- if colors is None:
- # Create default white colors for all points
- colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5
-
- # Ensure data types are correct (server expects float32)
- points = points.astype(np.float32)
- colors = colors.astype(np.float32)
-
- # Validate ranges (basic sanity checks)
- if np.any(np.isnan(points)) or np.any(np.isinf(points)):
- logger.error("Points contain NaN or Inf values")
- return None
- if np.any(np.isnan(colors)) or np.any(np.isinf(colors)):
- logger.error("Colors contain NaN or Inf values")
- return None
-
- # Clamp color values to valid range [0, 1]
- colors = np.clip(colors, 0.0, 1.0)
-
- async with websockets.connect(self.grasp_server_url) as websocket: # type: ignore[arg-type]
- request = {
- "points": points.tolist(),
- "colors": colors.tolist(), # Always send colors array
- "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], # Default workspace limits
- }
-
- await websocket.send(json.dumps(request))
-
- response = await websocket.recv()
- grasps = json.loads(response)
-
- # Handle server response validation
- if isinstance(grasps, dict) and "error" in grasps:
- logger.error(f"Server returned error: {grasps['error']}")
- return None
- elif isinstance(grasps, int | float) and grasps == 0:
- return None
- elif not isinstance(grasps, list):
- logger.error(
- f"Server returned unexpected response type: {type(grasps)}, value: {grasps}"
- )
- return None
- elif len(grasps) == 0:
- return None
-
- converted_grasps = self._convert_grasp_format(grasps)
- with self.grasp_lock:
- self.latest_grasps = converted_grasps
- self.grasps_consumed = False # Reset consumed flag
-
- # Emit to reactive stream
- self.grasps_subject.on_next(self.latest_grasps)
-
- return converted_grasps
- except websockets.exceptions.ConnectionClosed as e:
- logger.error(f"WebSocket connection closed: {e}")
- except websockets.exceptions.WebSocketException as e:
- logger.error(f"WebSocket error: {e}")
- except json.JSONDecodeError as e:
- logger.error(f"Failed to parse server response as JSON: {e}")
- except Exception as e:
- logger.error(f"Error requesting grasps: {e}")
-
- return None
-
- def request_scene_grasps(self, objects: list[dict]) -> asyncio.Task | None: # type: ignore[type-arg]
- """Request grasps for entire scene by combining all object point clouds."""
- if not self.grasp_loop or not objects:
- return None
-
- all_points = []
- all_colors = []
- valid_objects = 0
-
- for _i, obj in enumerate(objects):
- # Validate point cloud data
- if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None:
- continue
-
- points = obj["point_cloud_numpy"]
- if not isinstance(points, np.ndarray) or points.size == 0:
- continue
-
- # Ensure points have correct shape (N, 3)
- if len(points.shape) != 2 or points.shape[1] != 3:
- continue
-
- # Validate colors if present
- colors = None
- if "colors_numpy" in obj and obj["colors_numpy"] is not None:
- colors = obj["colors_numpy"]
- if isinstance(colors, np.ndarray) and colors.size > 0:
- # Ensure colors match points count and have correct shape
- if colors.shape[0] != points.shape[0]:
- colors = None # Ignore colors for this object
- elif len(colors.shape) != 2 or colors.shape[1] != 3:
- colors = None # Ignore colors for this object
-
- all_points.append(points)
- if colors is not None:
- all_colors.append(colors)
- valid_objects += 1
-
- if not all_points:
- return None
-
- try:
- combined_points = np.vstack(all_points)
-
- # Only combine colors if ALL objects have valid colors
- combined_colors = None
- if len(all_colors) == valid_objects and len(all_colors) > 0:
- combined_colors = np.vstack(all_colors)
-
- # Validate final combined data
- if combined_points.size == 0:
- logger.warning("Combined point cloud is empty")
- return None
-
- if combined_colors is not None and combined_colors.shape[0] != combined_points.shape[0]:
- logger.warning(
- f"Color/point count mismatch: {combined_colors.shape[0]} colors vs {combined_points.shape[0]} points, dropping colors"
- )
- combined_colors = None
-
- except Exception as e:
- logger.error(f"Failed to combine point clouds: {e}")
- return None
-
- try:
- # Check if there's already a grasp task running
- if hasattr(self, "grasp_task") and self.grasp_task and not self.grasp_task.done():
- return self.grasp_task
-
- task = asyncio.run_coroutine_threadsafe(
- self._send_grasp_request(combined_points, combined_colors), self.grasp_loop
- )
-
- self.grasp_task = task
- return task
- except Exception:
- logger.warning("Failed to create grasp task")
- return None
-
- def get_latest_grasps(self, timeout: float = 5.0) -> list[dict] | None: # type: ignore[type-arg]
- """Get latest grasp results, waiting for new ones if current ones have been consumed."""
- # Mark current grasps as consumed and get a reference
- with self.grasp_lock:
- current_grasps = self.latest_grasps
- self.grasps_consumed = True
-
- # If we already have grasps and they haven't been consumed, return them
- if current_grasps is not None and not getattr(self, "grasps_consumed", False):
- return current_grasps
-
- # Wait for new grasps
- start_time = time.time()
- while time.time() - start_time < timeout:
- with self.grasp_lock:
- # Check if we have new grasps (different from what we marked as consumed)
- if self.latest_grasps is not None and not getattr(self, "grasps_consumed", False):
- return self.latest_grasps
- time.sleep(0.1) # Check every 100ms
-
- return None # Timeout reached
-
- def clear_grasps(self) -> None:
- """Clear all stored grasp results."""
- with self.grasp_lock:
- self.latest_grasps = []
-
- def _prepare_colors(self, colors: np.ndarray | None) -> np.ndarray | None: # type: ignore[type-arg]
- """Prepare colors array, converting from various formats if needed."""
- if colors is None:
- return None
-
- if colors.max() > 1.0:
- colors = colors / 255.0
-
- return colors
-
- def _convert_grasp_format(self, grasps: list[dict]) -> list[dict]: # type: ignore[type-arg]
- """Convert Grasp format to our visualization format."""
- converted = []
-
- for i, grasp in enumerate(grasps):
- rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3)))
- euler_angles = self._rotation_matrix_to_euler(rotation_matrix)
-
- converted_grasp = {
- "id": f"grasp_{i}",
- "score": grasp.get("score", 0.0),
- "width": grasp.get("width", 0.0),
- "height": grasp.get("height", 0.0),
- "depth": grasp.get("depth", 0.0),
- "translation": grasp.get("translation", [0, 0, 0]),
- "rotation_matrix": rotation_matrix.tolist(),
- "euler_angles": euler_angles,
- }
- converted.append(converted_grasp)
-
- converted.sort(key=lambda x: x["score"], reverse=True)
-
- return converted
-
- def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> dict[str, float]: # type: ignore[type-arg]
- """Convert rotation matrix to Euler angles (in radians)."""
- sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2)
-
- singular = sy < 1e-6
-
- if not singular:
- x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
- y = np.arctan2(-rotation_matrix[2, 0], sy)
- z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
- else:
- x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1])
- y = np.arctan2(-rotation_matrix[2, 0], sy)
- z = 0
-
- return {"roll": x, "pitch": y, "yaw": z}
-
- def cleanup(self) -> None:
- """Clean up resources."""
- if hasattr(self.detector, "cleanup"):
- self.detector.cleanup()
-
- if self.grasp_loop and self.grasp_loop_thread:
- self.grasp_loop.call_soon_threadsafe(self.grasp_loop.stop)
- self.grasp_loop_thread.join(timeout=1.0)
-
- if hasattr(self.pointcloud_filter, "cleanup"):
- self.pointcloud_filter.cleanup()
- logger.info("ManipulationPipeline cleaned up")
diff --git a/dimos/manipulation/manip_aio_processer.py b/dimos/manipulation/manip_aio_processer.py
deleted file mode 100644
index 71ed42bff3..0000000000
--- a/dimos/manipulation/manip_aio_processer.py
+++ /dev/null
@@ -1,422 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Sequential manipulation processor for single-frame processing without reactive streams.
-"""
-
-import time
-from typing import Any
-
-import cv2
-import numpy as np
-
-from dimos.perception.common.utils import (
- colorize_depth,
- combine_object_data,
- detection_results_to_object_data,
-)
-from dimos.perception.detection2d.detic_2d_det import ( # type: ignore[import-not-found, import-untyped]
- Detic2DDetector,
-)
-from dimos.perception.grasp_generation.grasp_generation import HostedGraspGenerator
-from dimos.perception.grasp_generation.utils import create_grasp_overlay
-from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering
-from dimos.perception.pointcloud.utils import (
- create_point_cloud_overlay_visualization,
- extract_and_cluster_misc_points,
- overlay_point_clouds_on_image,
-)
-from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class ManipulationProcessor:
- """
- Sequential manipulation processor for single-frame processing.
-
- Processes RGB-D frames through object detection, point cloud filtering,
- and grasp generation in a single thread without reactive streams.
- """
-
- def __init__(
- self,
- camera_intrinsics: list[float], # [fx, fy, cx, cy]
- min_confidence: float = 0.6,
- max_objects: int = 20,
- vocabulary: str | None = None,
- enable_grasp_generation: bool = False,
- grasp_server_url: str | None = None, # Required when enable_grasp_generation=True
- enable_segmentation: bool = True,
- ) -> None:
- """
- Initialize the manipulation processor.
-
- Args:
- camera_intrinsics: [fx, fy, cx, cy] camera parameters
- min_confidence: Minimum detection confidence threshold
- max_objects: Maximum number of objects to process
- vocabulary: Optional vocabulary for Detic detector
- enable_grasp_generation: Whether to enable grasp generation
- grasp_server_url: WebSocket URL for Dimensional Grasp server (required when enable_grasp_generation=True)
- enable_segmentation: Whether to enable semantic segmentation
- segmentation_model: Segmentation model to use (SAM 2 or FastSAM)
- """
- self.camera_intrinsics = camera_intrinsics
- self.min_confidence = min_confidence
- self.max_objects = max_objects
- self.enable_grasp_generation = enable_grasp_generation
- self.grasp_server_url = grasp_server_url
- self.enable_segmentation = enable_segmentation
-
- # Validate grasp generation requirements
- if enable_grasp_generation and not grasp_server_url:
- raise ValueError("grasp_server_url is required when enable_grasp_generation=True")
-
- # Initialize object detector
- self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence)
-
- # Initialize point cloud processor
- self.pointcloud_filter = PointcloudFiltering(
- color_intrinsics=camera_intrinsics,
- depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics
- max_num_objects=max_objects,
- )
-
- # Initialize semantic segmentation
- self.segmenter = None
- if self.enable_segmentation:
- self.segmenter = Sam2DSegmenter(
- use_tracker=False, # Disable tracker for simple segmentation
- use_analyzer=False, # Disable analyzer for simple segmentation
- )
-
- # Initialize grasp generator if enabled
- self.grasp_generator = None
- if self.enable_grasp_generation:
- try:
- self.grasp_generator = HostedGraspGenerator(server_url=grasp_server_url) # type: ignore[arg-type]
- logger.info("Hosted grasp generator initialized successfully")
- except Exception as e:
- logger.error(f"Failed to initialize hosted grasp generator: {e}")
- self.grasp_generator = None
- self.enable_grasp_generation = False
-
- logger.info(
- f"Initialized ManipulationProcessor with confidence={min_confidence}, "
- f"grasp_generation={enable_grasp_generation}"
- )
-
- def process_frame(
- self,
- rgb_image: np.ndarray, # type: ignore[type-arg]
- depth_image: np.ndarray, # type: ignore[type-arg]
- generate_grasps: bool | None = None,
- ) -> dict[str, Any]:
- """
- Process a single RGB-D frame through the complete pipeline.
-
- Args:
- rgb_image: RGB image (H, W, 3)
- depth_image: Depth image (H, W) in meters
- generate_grasps: Override grasp generation setting for this frame
-
- Returns:
- Dictionary containing:
- - detection_viz: Visualization of object detection
- - pointcloud_viz: Visualization of point cloud overlay
- - segmentation_viz: Visualization of semantic segmentation (if enabled)
- - detection2d_objects: Raw detection results as ObjectData
- - segmentation2d_objects: Raw segmentation results as ObjectData (if enabled)
- - detected_objects: Detection (Object Detection) objects with point clouds filtered
- - all_objects: Combined objects with intelligent duplicate removal
- - full_pointcloud: Complete scene point cloud (if point cloud processing enabled)
- - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN)
- - misc_voxel_grid: Open3D voxel grid approximating all misc/background points
- - misc_pointcloud_viz: Visualization of misc/background cluster overlay
- - grasps: Grasp results (list of dictionaries, if enabled)
- - grasp_overlay: Grasp visualization overlay (if enabled)
- - processing_time: Total processing time
- """
- start_time = time.time()
- results = {}
-
- try:
- # Step 1: Object Detection
- step_start = time.time()
- detection_results = self.run_object_detection(rgb_image)
- results["detection2d_objects"] = detection_results.get("objects", [])
- results["detection_viz"] = detection_results.get("viz_frame")
- detection_time = time.time() - step_start
-
- # Step 2: Semantic Segmentation (if enabled)
- segmentation_time = 0
- if self.enable_segmentation:
- step_start = time.time()
- segmentation_results = self.run_segmentation(rgb_image)
- results["segmentation2d_objects"] = segmentation_results.get("objects", [])
- results["segmentation_viz"] = segmentation_results.get("viz_frame")
- segmentation_time = time.time() - step_start # type: ignore[assignment]
-
- # Step 3: Point Cloud Processing
- pointcloud_time = 0
- detection2d_objects = results.get("detection2d_objects", [])
- segmentation2d_objects = results.get("segmentation2d_objects", [])
-
- # Process detection objects if available
- detected_objects = []
- if detection2d_objects:
- step_start = time.time()
- detected_objects = self.run_pointcloud_filtering(
- rgb_image, depth_image, detection2d_objects
- )
- pointcloud_time += time.time() - step_start # type: ignore[assignment]
-
- # Process segmentation objects if available
- segmentation_filtered_objects = []
- if segmentation2d_objects:
- step_start = time.time()
- segmentation_filtered_objects = self.run_pointcloud_filtering(
- rgb_image, depth_image, segmentation2d_objects
- )
- pointcloud_time += time.time() - step_start # type: ignore[assignment]
-
- # Combine all objects using intelligent duplicate removal
- all_objects = combine_object_data(
- detected_objects, # type: ignore[arg-type]
- segmentation_filtered_objects, # type: ignore[arg-type]
- overlap_threshold=0.8,
- )
-
- # Get full point cloud
- full_pcd = self.pointcloud_filter.get_full_point_cloud()
-
- # Extract misc/background points and create voxel grid
- misc_start = time.time()
- misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points(
- full_pcd,
- all_objects, # type: ignore[arg-type]
- eps=0.03,
- min_points=100,
- enable_filtering=True,
- voxel_size=0.02,
- )
- misc_time = time.time() - misc_start
-
- # Store results
- results.update(
- {
- "detected_objects": detected_objects,
- "all_objects": all_objects,
- "full_pointcloud": full_pcd,
- "misc_clusters": misc_clusters,
- "misc_voxel_grid": misc_voxel_grid,
- }
- )
-
- # Create point cloud visualizations
- base_image = colorize_depth(depth_image, max_depth=10.0)
-
- # Create visualizations
- results["pointcloud_viz"] = (
- create_point_cloud_overlay_visualization(
- base_image=base_image, # type: ignore[arg-type]
- objects=all_objects, # type: ignore[arg-type]
- intrinsics=self.camera_intrinsics, # type: ignore[arg-type]
- )
- if all_objects
- else base_image
- )
-
- results["detected_pointcloud_viz"] = (
- create_point_cloud_overlay_visualization(
- base_image=base_image, # type: ignore[arg-type]
- objects=detected_objects,
- intrinsics=self.camera_intrinsics, # type: ignore[arg-type]
- )
- if detected_objects
- else base_image
- )
-
- if misc_clusters:
- # Generate consistent colors for clusters
- cluster_colors = [
- tuple((np.random.RandomState(i + 100).rand(3) * 255).astype(int))
- for i in range(len(misc_clusters))
- ]
- results["misc_pointcloud_viz"] = overlay_point_clouds_on_image(
- base_image=base_image, # type: ignore[arg-type]
- point_clouds=misc_clusters,
- camera_intrinsics=self.camera_intrinsics,
- colors=cluster_colors,
- point_size=2,
- alpha=0.6,
- )
- else:
- results["misc_pointcloud_viz"] = base_image
-
- # Step 4: Grasp Generation (if enabled)
- should_generate_grasps = (
- generate_grasps if generate_grasps is not None else self.enable_grasp_generation
- )
-
- if should_generate_grasps and all_objects and full_pcd:
- grasps = self.run_grasp_generation(all_objects, full_pcd) # type: ignore[arg-type]
- results["grasps"] = grasps
- if grasps:
- results["grasp_overlay"] = create_grasp_overlay(
- rgb_image, grasps, self.camera_intrinsics
- )
-
- except Exception as e:
- logger.error(f"Error processing frame: {e}")
- results["error"] = str(e)
-
- # Add timing information
- total_time = time.time() - start_time
- results.update(
- {
- "processing_time": total_time,
- "timing_breakdown": {
- "detection": detection_time if "detection_time" in locals() else 0,
- "segmentation": segmentation_time if "segmentation_time" in locals() else 0,
- "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0,
- "misc_extraction": misc_time if "misc_time" in locals() else 0,
- "total": total_time,
- },
- }
- )
-
- return results
-
- def run_object_detection(self, rgb_image: np.ndarray) -> dict[str, Any]: # type: ignore[type-arg]
- """Run object detection on RGB image."""
- try:
- # Convert RGB to BGR for Detic detector
- bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
-
- # Use process_image method from Detic detector
- bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image(
- bgr_image
- )
-
- # Convert to ObjectData format using utility function
- objects = detection_results_to_object_data(
- bboxes=bboxes,
- track_ids=track_ids,
- class_ids=class_ids,
- confidences=confidences,
- names=names,
- masks=masks,
- source="detection",
- )
-
- # Create visualization using detector's built-in method
- viz_frame = self.detector.visualize_results(
- rgb_image, bboxes, track_ids, class_ids, confidences, names
- )
-
- return {"objects": objects, "viz_frame": viz_frame}
-
- except Exception as e:
- logger.error(f"Object detection failed: {e}")
- return {"objects": [], "viz_frame": rgb_image.copy()}
-
- def run_pointcloud_filtering(
- self,
- rgb_image: np.ndarray, # type: ignore[type-arg]
- depth_image: np.ndarray, # type: ignore[type-arg]
- objects: list[dict], # type: ignore[type-arg]
- ) -> list[dict]: # type: ignore[type-arg]
- """Run point cloud filtering on detected objects."""
- try:
- filtered_objects = self.pointcloud_filter.process_images(
- rgb_image,
- depth_image,
- objects, # type: ignore[arg-type]
- )
- return filtered_objects if filtered_objects else [] # type: ignore[return-value]
- except Exception as e:
- logger.error(f"Point cloud filtering failed: {e}")
- return []
-
- def run_segmentation(self, rgb_image: np.ndarray) -> dict[str, Any]: # type: ignore[type-arg]
- """Run semantic segmentation on RGB image."""
- if not self.segmenter:
- return {"objects": [], "viz_frame": rgb_image.copy()}
-
- try:
- # Convert RGB to BGR for segmenter
- bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
-
- # Get segmentation results
- masks, bboxes, track_ids, probs, names = self.segmenter.process_image(bgr_image) # type: ignore[no-untyped-call]
-
- # Convert to ObjectData format using utility function
- objects = detection_results_to_object_data(
- bboxes=bboxes,
- track_ids=track_ids,
- class_ids=list(range(len(bboxes))), # Use indices as class IDs for segmentation
- confidences=probs,
- names=names,
- masks=masks,
- source="segmentation",
- )
-
- # Create visualization
- if masks:
- viz_bgr = self.segmenter.visualize_results(
- bgr_image, masks, bboxes, track_ids, probs, names
- )
- # Convert back to RGB
- viz_frame = cv2.cvtColor(viz_bgr, cv2.COLOR_BGR2RGB)
- else:
- viz_frame = rgb_image.copy()
-
- return {"objects": objects, "viz_frame": viz_frame}
-
- except Exception as e:
- logger.error(f"Segmentation failed: {e}")
- return {"objects": [], "viz_frame": rgb_image.copy()}
-
- def run_grasp_generation(self, filtered_objects: list[dict], full_pcd) -> list[dict] | None: # type: ignore[no-untyped-def, type-arg]
- """Run grasp generation using the configured generator."""
- if not self.grasp_generator:
- logger.warning("Grasp generation requested but no generator available")
- return None
-
- try:
- # Generate grasps using the configured generator
- grasps = self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) # type: ignore[arg-type]
-
- # Return parsed results directly (list of grasp dictionaries)
- return grasps
-
- except Exception as e:
- logger.error(f"Grasp generation failed: {e}")
- return None
-
- def cleanup(self) -> None:
- """Clean up resources."""
- if hasattr(self.detector, "cleanup"):
- self.detector.cleanup()
- if hasattr(self.pointcloud_filter, "cleanup"):
- self.pointcloud_filter.cleanup()
- if self.segmenter and hasattr(self.segmenter, "cleanup"):
- self.segmenter.cleanup()
- if self.grasp_generator and hasattr(self.grasp_generator, "cleanup"):
- self.grasp_generator.cleanup()
- logger.info("ManipulationProcessor cleaned up")
diff --git a/dimos/manipulation/manipulation_interface.py b/dimos/manipulation/manipulation_interface.py
index edeb99c0f0..10e71fbc66 100644
--- a/dimos/manipulation/manipulation_interface.py
+++ b/dimos/manipulation/manipulation_interface.py
@@ -26,7 +26,6 @@
from dimos.manipulation.manipulation_history import (
ManipulationHistory,
)
-from dimos.perception.object_detection_stream import ObjectDetectionStream
from dimos.types.manipulation import (
AbstractConstraint,
ManipulationTask,
@@ -53,7 +52,7 @@ def __init__(
self,
output_dir: str,
new_memory: bool = False,
- perception_stream: ObjectDetectionStream = None, # type: ignore[assignment]
+ perception_stream: Any = None,
) -> None:
"""
Initialize a new ManipulationInterface instance.
diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py
deleted file mode 100644
index fca085df8c..0000000000
--- a/dimos/manipulation/visual_servoing/detection3d.py
+++ /dev/null
@@ -1,302 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Real-time 3D object detection processor that extracts object poses from RGB-D data.
-"""
-
-import cv2
-from dimos_lcm.vision_msgs import (
- BoundingBox2D,
- BoundingBox3D,
- Detection2D,
- Detection3D,
- ObjectHypothesis,
- ObjectHypothesisWithPose,
- Point2D,
- Pose2D,
-)
-import numpy as np
-
-from dimos.manipulation.visual_servoing.utils import (
- estimate_object_depth,
- transform_pose,
- visualize_detections_3d,
-)
-from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3
-from dimos.msgs.std_msgs import Header
-from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray
-from dimos.perception.common.utils import bbox2d_to_corners
-from dimos.perception.detection2d.utils import calculate_object_size_from_bbox
-from dimos.perception.pointcloud.utils import extract_centroids_from_masks
-from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class Detection3DProcessor:
- """
- Real-time 3D detection processor optimized for speed.
-
- Uses Sam (FastSAM) for segmentation and mask generation, then extracts
- 3D centroids from depth data.
- """
-
- def __init__(
- self,
- camera_intrinsics: list[float], # [fx, fy, cx, cy]
- min_confidence: float = 0.6,
- min_points: int = 30,
- max_depth: float = 1.0,
- max_object_size: float = 0.15,
- ) -> None:
- """
- Initialize the real-time 3D detection processor.
-
- Args:
- camera_intrinsics: [fx, fy, cx, cy] camera parameters
- min_confidence: Minimum detection confidence threshold
- min_points: Minimum 3D points required for valid detection
- max_depth: Maximum valid depth in meters
- """
- self.camera_intrinsics = camera_intrinsics
- self.min_points = min_points
- self.max_depth = max_depth
- self.max_object_size = max_object_size
-
- # Initialize Sam segmenter with tracking enabled but analysis disabled
- self.detector = Sam2DSegmenter(
- use_tracker=False,
- use_analyzer=False,
- use_filtering=True,
- )
-
- self.min_confidence = min_confidence
-
- logger.info(
- f"Initialized Detection3DProcessor with Sam segmenter, confidence={min_confidence}, "
- f"min_points={min_points}, max_depth={max_depth}m, max_object_size={max_object_size}m"
- )
-
- def process_frame(
- self,
- rgb_image: np.ndarray, # type: ignore[type-arg]
- depth_image: np.ndarray, # type: ignore[type-arg]
- transform: np.ndarray | None = None, # type: ignore[type-arg]
- ) -> tuple[Detection3DArray, Detection2DArray]:
- """
- Process a single RGB-D frame to extract 3D object detections.
-
- Args:
- rgb_image: RGB image (H, W, 3)
- depth_image: Depth image (H, W) in meters
- transform: Optional 4x4 transformation matrix to transform objects from camera frame to desired frame
-
- Returns:
- Tuple of (Detection3DArray, Detection2DArray) with 3D and 2D information
- """
-
- # Convert RGB to BGR for Sam (OpenCV format)
- bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
-
- # Run Sam segmentation with tracking
- masks, bboxes, track_ids, probs, names = self.detector.process_image(bgr_image) # type: ignore[no-untyped-call]
-
- if not masks or len(masks) == 0:
- return Detection3DArray(
- detections_length=0, header=Header(), detections=[]
- ), Detection2DArray(detections_length=0, header=Header(), detections=[])
-
- # Convert CUDA tensors to numpy arrays if needed
- numpy_masks = []
- for mask in masks:
- if hasattr(mask, "cpu"): # PyTorch tensor
- numpy_masks.append(mask.cpu().numpy())
- else: # Already numpy array
- numpy_masks.append(mask)
-
- # Extract 3D centroids from masks
- poses = extract_centroids_from_masks(
- rgb_image=rgb_image,
- depth_image=depth_image,
- masks=numpy_masks,
- camera_intrinsics=self.camera_intrinsics,
- )
-
- detections_3d = []
- detections_2d = []
- pose_dict = {p["mask_idx"]: p for p in poses if p["centroid"][2] < self.max_depth}
-
- for i, (bbox, name, prob, track_id) in enumerate(
- zip(bboxes, names, probs, track_ids, strict=False)
- ):
- if i not in pose_dict:
- continue
-
- pose = pose_dict[i]
- obj_cam_pos = pose["centroid"]
-
- if obj_cam_pos[2] > self.max_depth:
- continue
-
- # Calculate object size from bbox and depth
- width_m, height_m = calculate_object_size_from_bbox(
- bbox, obj_cam_pos[2], self.camera_intrinsics
- )
-
- # Calculate depth dimension using segmentation mask
- depth_m = estimate_object_depth(
- depth_image, numpy_masks[i] if i < len(numpy_masks) else None, bbox
- )
-
- size_x = max(width_m, 0.01) # Minimum 1cm width
- size_y = max(height_m, 0.01) # Minimum 1cm height
- size_z = max(depth_m, 0.01) # Minimum 1cm depth
-
- if min(size_x, size_y, size_z) > self.max_object_size:
- continue
-
- # Transform to desired frame if transform matrix is provided
- if transform is not None:
- # Get orientation as euler angles, default to no rotation if not available
- obj_cam_orientation = pose.get(
- "rotation", np.array([0.0, 0.0, 0.0])
- ) # Default to no rotation
- transformed_pose = transform_pose(
- obj_cam_pos, obj_cam_orientation, transform, to_robot=True
- )
- center_pose = transformed_pose
- else:
- # If no transform, use camera coordinates
- center_pose = Pose(
- position=Vector3(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]),
- orientation=Quaternion(0.0, 0.0, 0.0, 1.0), # Default orientation
- )
-
- # Create Detection3D object
- detection = Detection3D(
- results_length=1,
- header=Header(), # Empty header
- results=[
- ObjectHypothesisWithPose(
- hypothesis=ObjectHypothesis(class_id=name, score=float(prob))
- )
- ],
- bbox=BoundingBox3D(center=center_pose, size=Vector3(size_x, size_y, size_z)),
- id=str(track_id),
- )
-
- detections_3d.append(detection)
-
- # Create corresponding Detection2D
- x1, y1, x2, y2 = bbox
- center_x = (x1 + x2) / 2.0
- center_y = (y1 + y2) / 2.0
- width = x2 - x1
- height = y2 - y1
-
- detection_2d = Detection2D(
- results_length=1,
- header=Header(),
- results=[
- ObjectHypothesisWithPose(
- hypothesis=ObjectHypothesis(class_id=name, score=float(prob))
- )
- ],
- bbox=BoundingBox2D(
- center=Pose2D(position=Point2D(center_x, center_y), theta=0.0),
- size_x=float(width),
- size_y=float(height),
- ),
- id=str(track_id),
- )
- detections_2d.append(detection_2d)
-
- # Create and return both arrays
- return (
- Detection3DArray(
- detections_length=len(detections_3d), header=Header(), detections=detections_3d
- ),
- Detection2DArray(
- detections_length=len(detections_2d), header=Header(), detections=detections_2d
- ),
- )
-
- def visualize_detections(
- self,
- rgb_image: np.ndarray, # type: ignore[type-arg]
- detections_3d: list[Detection3D],
- detections_2d: list[Detection2D],
- show_coordinates: bool = True,
- ) -> np.ndarray: # type: ignore[type-arg]
- """
- Visualize detections with 3D position overlay next to bounding boxes.
-
- Args:
- rgb_image: Original RGB image
- detections_3d: List of Detection3D objects
- detections_2d: List of Detection2D objects (must be 1:1 correspondence)
- show_coordinates: Whether to show 3D coordinates
-
- Returns:
- Visualization image
- """
- # Extract 2D bboxes from Detection2D objects
-
- bboxes_2d = []
- for det_2d in detections_2d:
- if det_2d.bbox:
- x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox)
- bboxes_2d.append([x1, y1, x2, y2])
-
- return visualize_detections_3d(rgb_image, detections_3d, show_coordinates, bboxes_2d)
-
- def get_closest_detection(
- self, detections: list[Detection3D], class_filter: str | None = None
- ) -> Detection3D | None:
- """
- Get the closest detection with valid 3D data.
-
- Args:
- detections: List of Detection3D objects
- class_filter: Optional class name to filter by
-
- Returns:
- Closest Detection3D or None
- """
- valid_detections = []
- for d in detections:
- # Check if has valid bbox center position
- if d.bbox and d.bbox.center and d.bbox.center.position:
- # Check class filter if specified
- if class_filter is None or (
- d.results_length > 0 and d.results[0].hypothesis.class_id == class_filter
- ):
- valid_detections.append(d)
-
- if not valid_detections:
- return None
-
- # Sort by depth (Z coordinate)
- def get_z_coord(d): # type: ignore[no-untyped-def]
- return abs(d.bbox.center.position.z)
-
- return min(valid_detections, key=get_z_coord)
-
- def cleanup(self) -> None:
- """Clean up resources."""
- if hasattr(self.detector, "cleanup"):
- self.detector.cleanup()
- logger.info("Detection3DProcessor cleaned up")
diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py
deleted file mode 100644
index 088db9eb26..0000000000
--- a/dimos/manipulation/visual_servoing/manipulation_module.py
+++ /dev/null
@@ -1,951 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Manipulation module for robotic grasping with visual servoing.
-Handles grasping logic, state machine, and hardware coordination as a Dimos module.
-"""
-
-from collections import deque
-from enum import Enum
-import threading
-import time
-from typing import Any
-
-import cv2
-from dimos_lcm.sensor_msgs import CameraInfo
-import numpy as np
-from reactivex.disposable import Disposable
-
-from dimos.core import In, Module, Out, rpc
-from dimos.hardware.manipulators.piper.piper_arm import ( # type: ignore[import-not-found, import-untyped]
- PiperArm,
-)
-from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor
-from dimos.manipulation.visual_servoing.pbvs import PBVS
-from dimos.manipulation.visual_servoing.utils import (
- create_manipulation_visualization,
- is_target_reached,
- select_points_from_depth,
- transform_points_3d,
- update_target_grasp_pose,
-)
-from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3
-from dimos.msgs.sensor_msgs import Image
-from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray
-from dimos.perception.common.utils import find_clicked_detection
-from dimos.utils.logging_config import setup_logger
-from dimos.utils.transform_utils import (
- compose_transforms,
- create_transform_from_6dof,
- matrix_to_pose,
- pose_to_matrix,
-)
-
-logger = setup_logger()
-
-
-class GraspStage(Enum):
- """Enum for different grasp stages."""
-
- IDLE = "idle"
- PRE_GRASP = "pre_grasp"
- GRASP = "grasp"
- CLOSE_AND_RETRACT = "close_and_retract"
- PLACE = "place"
- RETRACT = "retract"
-
-
-class Feedback:
- """Feedback data containing state information about the manipulation process."""
-
- def __init__(
- self,
- grasp_stage: GraspStage,
- target_tracked: bool,
- current_executed_pose: Pose | None = None,
- current_ee_pose: Pose | None = None,
- current_camera_pose: Pose | None = None,
- target_pose: Pose | None = None,
- waiting_for_reach: bool = False,
- success: bool | None = None,
- ) -> None:
- self.grasp_stage = grasp_stage
- self.target_tracked = target_tracked
- self.current_executed_pose = current_executed_pose
- self.current_ee_pose = current_ee_pose
- self.current_camera_pose = current_camera_pose
- self.target_pose = target_pose
- self.waiting_for_reach = waiting_for_reach
- self.success = success
-
-
-class ManipulationModule(Module):
- """
- Manipulation module for visual servoing and grasping.
-
- Subscribes to:
- - ZED RGB images
- - ZED depth images
- - ZED camera info
-
- Publishes:
- - Visualization images
-
- RPC methods:
- - handle_keyboard_command: Process keyboard input
- - pick_and_place: Execute pick and place task
- """
-
- # LCM inputs
- rgb_image: In[Image]
- depth_image: In[Image]
- camera_info: In[CameraInfo]
-
- # LCM outputs
- viz_image: Out[Image]
-
- def __init__( # type: ignore[no-untyped-def]
- self,
- ee_to_camera_6dof: list | None = None, # type: ignore[type-arg]
- **kwargs,
- ) -> None:
- """
- Initialize manipulation module.
-
- Args:
- ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians
- workspace_min_radius: Minimum workspace radius in meters
- workspace_max_radius: Maximum workspace radius in meters
- min_grasp_pitch_degrees: Minimum grasp pitch angle (at max radius)
- max_grasp_pitch_degrees: Maximum grasp pitch angle (at min radius)
- """
- super().__init__(**kwargs)
-
- self.arm = PiperArm()
-
- if ee_to_camera_6dof is None:
- ee_to_camera_6dof = [-0.065, 0.03, -0.095, 0.0, -1.57, 0.0]
- pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2])
- rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5])
- self.T_ee_to_camera = create_transform_from_6dof(pos, rot)
-
- self.camera_intrinsics = None
- self.detector = None
- self.pbvs = None
-
- # Control state
- self.last_valid_target = None
- self.waiting_for_reach = False
- self.current_executed_pose = None # Track the actual pose sent to arm
- self.target_updated = False
- self.waiting_start_time = None
- self.reach_pose_timeout = 20.0
-
- # Grasp parameters
- self.grasp_width_offset = 0.03
- self.pregrasp_distance = 0.25
- self.grasp_distance_range = 0.03
- self.grasp_close_delay = 2.0
- self.grasp_reached_time = None
- self.gripper_max_opening = 0.07
-
- # Workspace limits and dynamic pitch parameters
- self.workspace_min_radius = 0.2
- self.workspace_max_radius = 0.75
- self.min_grasp_pitch_degrees = 5.0
- self.max_grasp_pitch_degrees = 60.0
-
- # Grasp stage tracking
- self.grasp_stage = GraspStage.IDLE
-
- # Pose stabilization tracking
- self.pose_history_size = 4
- self.pose_stabilization_threshold = 0.01
- self.stabilization_timeout = 25.0
- self.stabilization_start_time = None
- self.reached_poses = deque(maxlen=self.pose_history_size) # type: ignore[var-annotated]
- self.adjustment_count = 0
-
- # Pose reachability tracking
- self.ee_pose_history = deque(maxlen=20) # type: ignore[var-annotated] # Keep history of EE poses
- self.stuck_pose_threshold = 0.001 # 1mm movement threshold
- self.stuck_pose_adjustment_degrees = 5.0
- self.stuck_count = 0
- self.max_stuck_reattempts = 7
-
- # State for visualization
- self.current_visualization = None
- self.last_detection_3d_array = None
- self.last_detection_2d_array = None
-
- # Grasp result and task tracking
- self.pick_success = None
- self.final_pregrasp_pose = None
- self.task_failed = False
- self.overall_success = None
-
- # Task control
- self.task_running = False
- self.task_thread = None
- self.stop_event = threading.Event()
-
- # Latest sensor data
- self.latest_rgb = None
- self.latest_depth = None
- self.latest_camera_info = None
-
- # Target selection
- self.target_click = None
-
- # Place target position and object info
- self.home_pose = Pose(
- position=Vector3(0.0, 0.0, 0.0), orientation=Quaternion(0.0, 0.0, 0.0, 1.0)
- )
- self.place_target_position = None
- self.target_object_height = None
- self.retract_distance = 0.12
- self.place_pose = None
- self.retract_pose = None
- self.arm.gotoObserve()
-
- @rpc
- def start(self) -> None:
- """Start the manipulation module."""
-
- unsub = self.rgb_image.subscribe(self._on_rgb_image)
- self._disposables.add(Disposable(unsub))
-
- unsub = self.depth_image.subscribe(self._on_depth_image)
- self._disposables.add(Disposable(unsub))
-
- unsub = self.camera_info.subscribe(self._on_camera_info)
- self._disposables.add(Disposable(unsub))
-
- logger.info("Manipulation module started")
-
- @rpc
- def stop(self) -> None:
- """Stop the manipulation module."""
- # Stop any running task
- self.stop_event.set()
- if self.task_thread and self.task_thread.is_alive():
- self.task_thread.join(timeout=5.0)
-
- self.reset_to_idle()
-
- if self.detector and hasattr(self.detector, "cleanup"):
- self.detector.cleanup()
- self.arm.disable()
-
- logger.info("Manipulation module stopped")
-
- def _on_rgb_image(self, msg: Image) -> None:
- """Handle RGB image messages."""
- try:
- self.latest_rgb = msg.data
- except Exception as e:
- logger.error(f"Error processing RGB image: {e}")
-
- def _on_depth_image(self, msg: Image) -> None:
- """Handle depth image messages."""
- try:
- self.latest_depth = msg.data
- except Exception as e:
- logger.error(f"Error processing depth image: {e}")
-
- def _on_camera_info(self, msg: CameraInfo) -> None:
- """Handle camera info messages."""
- try:
- self.camera_intrinsics = [msg.K[0], msg.K[4], msg.K[2], msg.K[5]] # type: ignore[assignment]
-
- if self.detector is None:
- self.detector = Detection3DProcessor(self.camera_intrinsics) # type: ignore[arg-type, assignment]
- self.pbvs = PBVS() # type: ignore[assignment]
- logger.info("Initialized detection and PBVS processors")
-
- self.latest_camera_info = msg
- except Exception as e:
- logger.error(f"Error processing camera info: {e}")
-
- @rpc
- def get_single_rgb_frame(self) -> np.ndarray | None: # type: ignore[type-arg]
- """
- get the latest rgb frame from the camera
- """
- return self.latest_rgb
-
- @rpc
- def handle_keyboard_command(self, key: str) -> str:
- """
- Handle keyboard commands for robot control.
-
- Args:
- key: Keyboard key as string
-
- Returns:
- Action taken as string, or empty string if no action
- """
- key_code = ord(key) if len(key) == 1 else int(key)
-
- if key_code == ord("r"):
- self.stop_event.set()
- self.task_running = False
- self.reset_to_idle()
- return "reset"
- elif key_code == ord("s"):
- logger.info("SOFT STOP - Emergency stopping robot!")
- self.arm.softStop()
- self.stop_event.set()
- self.task_running = False
- return "stop"
- elif key_code == ord(" ") and self.pbvs and self.pbvs.target_grasp_pose:
- if self.grasp_stage == GraspStage.PRE_GRASP:
- self.set_grasp_stage(GraspStage.GRASP)
- logger.info("Executing target pose")
- return "execute"
- elif key_code == ord("g"):
- logger.info("Opening gripper")
- self.arm.release_gripper()
- return "release"
-
- return ""
-
- @rpc
- def pick_and_place(
- self,
- target_x: int | None = None,
- target_y: int | None = None,
- place_x: int | None = None,
- place_y: int | None = None,
- ) -> dict[str, Any]:
- """
- Start a pick and place task.
-
- Args:
- target_x: Optional X coordinate of target object
- target_y: Optional Y coordinate of target object
- place_x: Optional X coordinate of place location
- place_y: Optional Y coordinate of place location
-
- Returns:
- Dict with status and message
- """
- if self.task_running:
- return {"status": "error", "message": "Task already running"}
-
- if self.camera_intrinsics is None:
- return {"status": "error", "message": "Camera not initialized"}
-
- if target_x is not None and target_y is not None:
- self.target_click = (target_x, target_y)
- if place_x is not None and self.latest_depth is not None:
- points_3d_camera = select_points_from_depth(
- self.latest_depth,
- (place_x, place_y),
- self.camera_intrinsics,
- radius=10,
- )
-
- if points_3d_camera.size > 0:
- ee_pose = self.arm.get_ee_pose()
- ee_transform = pose_to_matrix(ee_pose)
- camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera)
-
- points_3d_world = transform_points_3d(
- points_3d_camera,
- camera_transform,
- to_robot=True,
- )
-
- place_position = np.mean(points_3d_world, axis=0)
- self.place_target_position = place_position
- logger.info(
- f"Place target set at position: ({place_position[0]:.3f}, {place_position[1]:.3f}, {place_position[2]:.3f})"
- )
- else:
- logger.warning("No valid depth points found at place location")
- self.place_target_position = None
- else:
- self.place_target_position = None
-
- self.task_failed = False
- self.stop_event.clear()
-
- if self.task_thread and self.task_thread.is_alive():
- self.stop_event.set()
- self.task_thread.join(timeout=1.0)
- self.task_thread = threading.Thread(target=self._run_pick_and_place, daemon=True)
- self.task_thread.start()
-
- return {"status": "started", "message": "Pick and place task started"}
-
- def _run_pick_and_place(self) -> None:
- """Run the pick and place task loop."""
- self.task_running = True
- logger.info("Starting pick and place task")
-
- try:
- while not self.stop_event.is_set():
- if self.task_failed:
- logger.error("Task failed, terminating pick and place")
- self.stop_event.set()
- break
-
- feedback = self.update()
- if feedback is None:
- time.sleep(0.01)
- continue
-
- if feedback.success is not None: # type: ignore[attr-defined]
- if feedback.success: # type: ignore[attr-defined]
- logger.info("Pick and place completed successfully!")
- else:
- logger.warning("Pick and place failed")
- self.reset_to_idle()
- self.stop_event.set()
- break
-
- time.sleep(0.01)
-
- except Exception as e:
- logger.error(f"Error in pick and place task: {e}")
- self.task_failed = True
- finally:
- self.task_running = False
- logger.info("Pick and place task ended")
-
- def set_grasp_stage(self, stage: GraspStage) -> None:
- """Set the grasp stage."""
- self.grasp_stage = stage
- logger.info(f"Grasp stage: {stage.value}")
-
- def calculate_dynamic_grasp_pitch(self, target_pose: Pose) -> float:
- """
- Calculate grasp pitch dynamically based on distance from robot base.
- Maps workspace radius to grasp pitch angle.
-
- Args:
- target_pose: Target pose
-
- Returns:
- Grasp pitch angle in degrees
- """
- # Calculate 3D distance from robot base (assumes robot at origin)
- position = target_pose.position
- distance = np.sqrt(position.x**2 + position.y**2 + position.z**2)
-
- # Clamp distance to workspace limits
- distance = np.clip(distance, self.workspace_min_radius, self.workspace_max_radius)
-
- # Linear interpolation: min_radius -> max_pitch, max_radius -> min_pitch
- # Normalized distance (0 to 1)
- normalized_dist = (distance - self.workspace_min_radius) / (
- self.workspace_max_radius - self.workspace_min_radius
- )
-
- # Inverse mapping: closer objects need higher pitch
- pitch_degrees = self.max_grasp_pitch_degrees - (
- normalized_dist * (self.max_grasp_pitch_degrees - self.min_grasp_pitch_degrees)
- )
-
- return pitch_degrees # type: ignore[no-any-return]
-
- def check_within_workspace(self, target_pose: Pose) -> bool:
- """
- Check if pose is within workspace limits and log error if not.
-
- Args:
- target_pose: Target pose to validate
-
- Returns:
- True if within workspace, False otherwise
- """
- # Calculate 3D distance from robot base
- position = target_pose.position
- distance = np.sqrt(position.x**2 + position.y**2 + position.z**2)
-
- if not (self.workspace_min_radius <= distance <= self.workspace_max_radius):
- logger.error(
- f"Target outside workspace limits: distance {distance:.3f}m not in [{self.workspace_min_radius:.2f}, {self.workspace_max_radius:.2f}]"
- )
- return False
-
- return True
-
- def _check_reach_timeout(self) -> tuple[bool, float]:
- """Check if robot has exceeded timeout while reaching pose.
-
- Returns:
- Tuple of (timed_out, time_elapsed)
- """
- if self.waiting_start_time:
- time_elapsed = time.time() - self.waiting_start_time
- if time_elapsed > self.reach_pose_timeout:
- logger.warning(
- f"Robot failed to reach pose within {self.reach_pose_timeout}s timeout"
- )
- self.task_failed = True
- self.reset_to_idle()
- return True, time_elapsed
- return False, time_elapsed
- return False, 0.0
-
- def _check_if_stuck(self) -> bool:
- """
- Check if robot is stuck by analyzing pose history.
-
- Returns:
- Tuple of (is_stuck, max_std_dev_mm)
- """
- if len(self.ee_pose_history) < self.ee_pose_history.maxlen: # type: ignore[operator]
- return False
-
- # Extract positions from pose history
- positions = np.array(
- [[p.position.x, p.position.y, p.position.z] for p in self.ee_pose_history]
- )
-
- # Calculate standard deviation of positions
- std_devs = np.std(positions, axis=0)
- # Check if all standard deviations are below stuck threshold
- is_stuck = np.all(std_devs < self.stuck_pose_threshold)
-
- return is_stuck # type: ignore[return-value]
-
- def check_reach_and_adjust(self) -> bool:
- """
- Check if robot has reached the current executed pose while waiting.
- Handles timeout internally by failing the task.
- Also detects if the robot is stuck (not moving towards target).
-
- Returns:
- True if reached, False if still waiting or not in waiting state
- """
- if not self.waiting_for_reach or not self.current_executed_pose:
- return False
-
- # Get current end-effector pose
- ee_pose = self.arm.get_ee_pose()
- target_pose = self.current_executed_pose
-
- # Check for timeout - this will fail task and reset if timeout occurred
- timed_out, _time_elapsed = self._check_reach_timeout()
- if timed_out:
- return False
-
- self.ee_pose_history.append(ee_pose)
-
- # Check if robot is stuck
- is_stuck = self._check_if_stuck()
- if is_stuck:
- if self.grasp_stage == GraspStage.RETRACT or self.grasp_stage == GraspStage.PLACE:
- self.waiting_for_reach = False
- self.waiting_start_time = None
- self.stuck_count = 0
- self.ee_pose_history.clear()
- return True
- self.stuck_count += 1
- pitch_degrees = self.calculate_dynamic_grasp_pitch(target_pose)
- if self.stuck_count % 2 == 0:
- pitch_degrees += self.stuck_pose_adjustment_degrees * (1 + self.stuck_count // 2)
- else:
- pitch_degrees -= self.stuck_pose_adjustment_degrees * (1 + self.stuck_count // 2)
-
- pitch_degrees = max(
- self.min_grasp_pitch_degrees, min(self.max_grasp_pitch_degrees, pitch_degrees)
- )
- updated_target_pose = update_target_grasp_pose(target_pose, ee_pose, 0.0, pitch_degrees)
- self.arm.cmd_ee_pose(updated_target_pose)
- self.current_executed_pose = updated_target_pose
- self.ee_pose_history.clear()
- self.waiting_for_reach = True
- self.waiting_start_time = time.time()
- return False
-
- if self.stuck_count >= self.max_stuck_reattempts:
- self.task_failed = True
- self.reset_to_idle()
- return False
-
- if is_target_reached(target_pose, ee_pose, self.pbvs.target_tolerance):
- self.waiting_for_reach = False
- self.waiting_start_time = None
- self.stuck_count = 0
- self.ee_pose_history.clear()
- return True
- return False
-
- def _update_tracking(self, detection_3d_array: Detection3DArray | None) -> bool:
- """Update tracking with new detections."""
- if not detection_3d_array or not self.pbvs:
- return False
-
- target_tracked = self.pbvs.update_tracking(detection_3d_array)
- if target_tracked:
- self.target_updated = True
- self.last_valid_target = self.pbvs.get_current_target()
- return target_tracked
-
- def reset_to_idle(self) -> None:
- """Reset the manipulation system to IDLE state."""
- if self.pbvs:
- self.pbvs.clear_target()
- self.grasp_stage = GraspStage.IDLE
- self.reached_poses.clear()
- self.ee_pose_history.clear()
- self.adjustment_count = 0
- self.waiting_for_reach = False
- self.current_executed_pose = None
- self.target_updated = False
- self.stabilization_start_time = None
- self.grasp_reached_time = None
- self.waiting_start_time = None
- self.pick_success = None
- self.final_pregrasp_pose = None
- self.overall_success = None
- self.place_pose = None
- self.retract_pose = None
- self.stuck_count = 0
-
- self.arm.gotoObserve()
-
- def execute_idle(self) -> None:
- """Execute idle stage."""
- pass
-
- def execute_pre_grasp(self) -> None:
- """Execute pre-grasp stage: visual servoing to pre-grasp position."""
- if self.waiting_for_reach:
- if self.check_reach_and_adjust():
- self.reached_poses.append(self.current_executed_pose)
- self.target_updated = False
- time.sleep(0.2)
- return
- if (
- self.stabilization_start_time
- and (time.time() - self.stabilization_start_time) > self.stabilization_timeout
- ):
- logger.warning(
- f"Failed to get stable grasp after {self.stabilization_timeout} seconds, resetting"
- )
- self.task_failed = True
- self.reset_to_idle()
- return
-
- ee_pose = self.arm.get_ee_pose()
- dynamic_pitch = self.calculate_dynamic_grasp_pitch(self.pbvs.current_target.bbox.center) # type: ignore[attr-defined]
-
- _, _, _, has_target, target_pose = self.pbvs.compute_control( # type: ignore[attr-defined]
- ee_pose, self.pregrasp_distance, dynamic_pitch
- )
- if target_pose and has_target:
- # Validate target pose is within workspace
- if not self.check_within_workspace(target_pose):
- self.task_failed = True
- self.reset_to_idle()
- return
-
- if self.check_target_stabilized():
- logger.info("Target stabilized, transitioning to GRASP")
- self.final_pregrasp_pose = self.current_executed_pose
- self.grasp_stage = GraspStage.GRASP
- self.adjustment_count = 0
- self.waiting_for_reach = False
- elif not self.waiting_for_reach and self.target_updated:
- self.arm.cmd_ee_pose(target_pose)
- self.current_executed_pose = target_pose
- self.waiting_for_reach = True
- self.waiting_start_time = time.time() # type: ignore[assignment]
- self.target_updated = False
- self.adjustment_count += 1
- time.sleep(0.2)
-
- def execute_grasp(self) -> None:
- """Execute grasp stage: move to final grasp position."""
- if self.waiting_for_reach:
- if self.check_reach_and_adjust() and not self.grasp_reached_time:
- self.grasp_reached_time = time.time() # type: ignore[assignment]
- return
-
- if self.grasp_reached_time:
- if (time.time() - self.grasp_reached_time) >= self.grasp_close_delay:
- logger.info("Grasp delay completed, closing gripper")
- self.grasp_stage = GraspStage.CLOSE_AND_RETRACT
- return
-
- if self.last_valid_target:
- # Calculate dynamic pitch for current target
- dynamic_pitch = self.calculate_dynamic_grasp_pitch(self.last_valid_target.bbox.center)
- normalized_pitch = dynamic_pitch / 90.0
- grasp_distance = -self.grasp_distance_range + (
- 2 * self.grasp_distance_range * normalized_pitch
- )
-
- ee_pose = self.arm.get_ee_pose()
- _, _, _, has_target, target_pose = self.pbvs.compute_control(
- ee_pose, grasp_distance, dynamic_pitch
- )
-
- if target_pose and has_target:
- # Validate grasp pose is within workspace
- if not self.check_within_workspace(target_pose):
- self.task_failed = True
- self.reset_to_idle()
- return
-
- object_width = self.last_valid_target.bbox.size.x
- gripper_opening = max(
- 0.005, min(object_width + self.grasp_width_offset, self.gripper_max_opening)
- )
-
- logger.info(f"Executing grasp: gripper={gripper_opening * 1000:.1f}mm")
- self.arm.cmd_gripper_ctrl(gripper_opening)
- self.arm.cmd_ee_pose(target_pose, line_mode=True)
- self.current_executed_pose = target_pose
- self.waiting_for_reach = True
- self.waiting_start_time = time.time()
-
- def execute_close_and_retract(self) -> None:
- """Execute the retraction sequence after gripper has been closed."""
- if self.waiting_for_reach and self.final_pregrasp_pose:
- if self.check_reach_and_adjust():
- logger.info("Reached pre-grasp retraction position")
- self.pick_success = self.arm.gripper_object_detected()
- if self.pick_success:
- logger.info("Object successfully grasped!")
- if self.place_target_position is not None:
- logger.info("Transitioning to PLACE stage")
- self.grasp_stage = GraspStage.PLACE
- else:
- self.overall_success = True
- else:
- logger.warning("No object detected in gripper")
- self.task_failed = True
- self.overall_success = False
- return
- if not self.waiting_for_reach:
- logger.info("Retracting to pre-grasp position")
- self.arm.cmd_ee_pose(self.final_pregrasp_pose, line_mode=True)
- self.current_executed_pose = self.final_pregrasp_pose
- self.arm.close_gripper()
- self.waiting_for_reach = True
- self.waiting_start_time = time.time() # type: ignore[assignment]
-
- def execute_place(self) -> None:
- """Execute place stage: move to place position and release object."""
- if self.waiting_for_reach:
- # Use the already executed pose instead of recalculating
- if self.check_reach_and_adjust():
- logger.info("Reached place position, releasing gripper")
- self.arm.release_gripper()
- time.sleep(1.0)
- self.place_pose = self.current_executed_pose
- logger.info("Transitioning to RETRACT stage")
- self.grasp_stage = GraspStage.RETRACT
- return
-
- if not self.waiting_for_reach:
- place_pose = self.get_place_target_pose()
- if place_pose:
- logger.info("Moving to place position")
- self.arm.cmd_ee_pose(place_pose, line_mode=True)
- self.current_executed_pose = place_pose # type: ignore[assignment]
- self.waiting_for_reach = True
- self.waiting_start_time = time.time() # type: ignore[assignment]
- else:
- logger.error("Failed to get place target pose")
- self.task_failed = True
- self.overall_success = False # type: ignore[assignment]
-
- def execute_retract(self) -> None:
- """Execute retract stage: retract from place position."""
- if self.waiting_for_reach and self.retract_pose:
- if self.check_reach_and_adjust():
- logger.info("Reached retract position")
- logger.info("Returning to observe position")
- self.arm.gotoObserve()
- self.arm.close_gripper()
- self.overall_success = True
- logger.info("Pick and place completed successfully!")
- return
-
- if not self.waiting_for_reach:
- if self.place_pose:
- pose_pitch = self.calculate_dynamic_grasp_pitch(self.place_pose)
- self.retract_pose = update_target_grasp_pose(
- self.place_pose, self.home_pose, self.retract_distance, pose_pitch
- )
- logger.info("Retracting from place position")
- self.arm.cmd_ee_pose(self.retract_pose, line_mode=True)
- self.current_executed_pose = self.retract_pose
- self.waiting_for_reach = True
- self.waiting_start_time = time.time()
- else:
- logger.error("No place pose stored for retraction")
- self.task_failed = True
- self.overall_success = False # type: ignore[assignment]
-
- def capture_and_process(
- self,
- ) -> tuple[np.ndarray | None, Detection3DArray | None, Detection2DArray | None, Pose | None]: # type: ignore[type-arg]
- """Capture frame from camera data and process detections."""
- if self.latest_rgb is None or self.latest_depth is None or self.detector is None:
- return None, None, None, None
-
- ee_pose = self.arm.get_ee_pose()
- ee_transform = pose_to_matrix(ee_pose)
- camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera)
- camera_pose = matrix_to_pose(camera_transform)
- detection_3d_array, detection_2d_array = self.detector.process_frame(
- self.latest_rgb, self.latest_depth, camera_transform
- )
-
- return self.latest_rgb, detection_3d_array, detection_2d_array, camera_pose
-
- def pick_target(self, x: int, y: int) -> bool:
- """Select a target object at the given pixel coordinates."""
- if not self.last_detection_2d_array or not self.last_detection_3d_array:
- logger.warning("No detections available for target selection")
- return False
-
- clicked_3d = find_clicked_detection(
- (x, y), self.last_detection_2d_array.detections, self.last_detection_3d_array.detections
- )
- if clicked_3d and self.pbvs:
- # Validate workspace
- if not self.check_within_workspace(clicked_3d.bbox.center):
- self.task_failed = True
- return False
-
- self.pbvs.set_target(clicked_3d)
-
- if clicked_3d.bbox and clicked_3d.bbox.size:
- self.target_object_height = clicked_3d.bbox.size.z
- logger.info(f"Target object height: {self.target_object_height:.3f}m")
-
- position = clicked_3d.bbox.center.position
- logger.info(
- f"Target selected: ID={clicked_3d.id}, pos=({position.x:.3f}, {position.y:.3f}, {position.z:.3f})"
- )
- self.grasp_stage = GraspStage.PRE_GRASP
- self.reached_poses.clear()
- self.adjustment_count = 0
- self.waiting_for_reach = False
- self.current_executed_pose = None
- self.stabilization_start_time = time.time()
- return True
- return False
-
- def update(self) -> dict[str, Any] | None:
- """Main update function that handles capture, processing, control, and visualization."""
- rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process()
- if rgb is None:
- return None
-
- self.last_detection_3d_array = detection_3d_array # type: ignore[assignment]
- self.last_detection_2d_array = detection_2d_array # type: ignore[assignment]
- if self.target_click:
- x, y = self.target_click
- if self.pick_target(x, y):
- self.target_click = None
-
- if (
- detection_3d_array
- and self.grasp_stage in [GraspStage.PRE_GRASP, GraspStage.GRASP]
- and not self.waiting_for_reach
- ):
- self._update_tracking(detection_3d_array)
- stage_handlers = {
- GraspStage.IDLE: self.execute_idle,
- GraspStage.PRE_GRASP: self.execute_pre_grasp,
- GraspStage.GRASP: self.execute_grasp,
- GraspStage.CLOSE_AND_RETRACT: self.execute_close_and_retract,
- GraspStage.PLACE: self.execute_place,
- GraspStage.RETRACT: self.execute_retract,
- }
- if self.grasp_stage in stage_handlers:
- stage_handlers[self.grasp_stage]()
-
- target_tracked = self.pbvs.get_current_target() is not None if self.pbvs else False
- ee_pose = self.arm.get_ee_pose()
- feedback = Feedback(
- grasp_stage=self.grasp_stage,
- target_tracked=target_tracked,
- current_executed_pose=self.current_executed_pose,
- current_ee_pose=ee_pose,
- current_camera_pose=camera_pose,
- target_pose=self.pbvs.target_grasp_pose if self.pbvs else None,
- waiting_for_reach=self.waiting_for_reach,
- success=self.overall_success,
- )
-
- if self.task_running:
- self.current_visualization = create_manipulation_visualization( # type: ignore[assignment]
- rgb, feedback, detection_3d_array, detection_2d_array
- )
-
- if self.current_visualization is not None:
- self._publish_visualization(self.current_visualization)
-
- return feedback # type: ignore[return-value]
-
- def _publish_visualization(self, viz_image: np.ndarray) -> None: # type: ignore[type-arg]
- """Publish visualization image to LCM."""
- try:
- viz_rgb = cv2.cvtColor(viz_image, cv2.COLOR_BGR2RGB)
- msg = Image.from_numpy(viz_rgb)
- self.viz_image.publish(msg)
- except Exception as e:
- logger.error(f"Error publishing visualization: {e}")
-
- def check_target_stabilized(self) -> bool:
- """Check if the commanded poses have stabilized."""
- if len(self.reached_poses) < self.reached_poses.maxlen: # type: ignore[operator]
- return False
-
- positions = np.array(
- [[p.position.x, p.position.y, p.position.z] for p in self.reached_poses]
- )
- std_devs = np.std(positions, axis=0)
- return np.all(std_devs < self.pose_stabilization_threshold) # type: ignore[return-value]
-
- def get_place_target_pose(self) -> Pose | None:
- """Get the place target pose with z-offset applied based on object height."""
- if self.place_target_position is None:
- return None
-
- place_pos = self.place_target_position.copy()
- if self.target_object_height is not None:
- z_offset = self.target_object_height / 2.0
- place_pos[2] += z_offset + 0.1
-
- place_center_pose = Pose(
- position=Vector3(place_pos[0], place_pos[1], place_pos[2]),
- orientation=Quaternion(0.0, 0.0, 0.0, 1.0),
- )
-
- ee_pose = self.arm.get_ee_pose()
-
- # Calculate dynamic pitch for place position
- dynamic_pitch = self.calculate_dynamic_grasp_pitch(place_center_pose)
-
- place_pose = update_target_grasp_pose(
- place_center_pose,
- ee_pose,
- grasp_distance=0.0,
- grasp_pitch_degrees=dynamic_pitch,
- )
-
- return place_pose
diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py
deleted file mode 100644
index f94c233834..0000000000
--- a/dimos/manipulation/visual_servoing/pbvs.py
+++ /dev/null
@@ -1,488 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Position-Based Visual Servoing (PBVS) system for robotic manipulation.
-Supports both eye-in-hand and eye-to-hand configurations.
-"""
-
-from collections import deque
-
-from dimos_lcm.vision_msgs import Detection3D
-import numpy as np
-from scipy.spatial.transform import Rotation as R # type: ignore[import-untyped]
-
-from dimos.manipulation.visual_servoing.utils import (
- create_pbvs_visualization,
- find_best_object_match,
- is_target_reached,
- update_target_grasp_pose,
-)
-from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3
-from dimos.msgs.vision_msgs import Detection3DArray
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class PBVS:
- """
- High-level Position-Based Visual Servoing orchestrator.
-
- Handles:
- - Object tracking and target management
- - Pregrasp distance computation
- - Grasp pose generation
- - Coordination with low-level controller
-
- Note: This class is agnostic to camera mounting (eye-in-hand vs eye-to-hand).
- The caller is responsible for providing appropriate camera and EE poses.
- """
-
- def __init__(
- self,
- position_gain: float = 0.5,
- rotation_gain: float = 0.3,
- max_velocity: float = 0.1, # m/s
- max_angular_velocity: float = 0.5, # rad/s
- target_tolerance: float = 0.01, # 1cm
- max_tracking_distance_threshold: float = 0.12, # Max distance for target tracking (m)
- min_size_similarity: float = 0.6, # Min size similarity threshold (0.0-1.0)
- direct_ee_control: bool = True, # If True, output target poses instead of velocities
- ) -> None:
- """
- Initialize PBVS system.
-
- Args:
- position_gain: Proportional gain for position control
- rotation_gain: Proportional gain for rotation control
- max_velocity: Maximum linear velocity command magnitude (m/s)
- max_angular_velocity: Maximum angular velocity command magnitude (rad/s)
- target_tolerance: Distance threshold for considering target reached (m)
- max_tracking_distance: Maximum distance for valid target tracking (m)
- min_size_similarity: Minimum size similarity for valid target tracking (0.0-1.0)
- direct_ee_control: If True, output target poses instead of velocity commands
- """
- # Initialize low-level controller only if not in direct control mode
- if not direct_ee_control:
- self.controller = PBVSController(
- position_gain=position_gain,
- rotation_gain=rotation_gain,
- max_velocity=max_velocity,
- max_angular_velocity=max_angular_velocity,
- target_tolerance=target_tolerance,
- )
- else:
- self.controller = None # type: ignore[assignment]
-
- # Store parameters for direct mode error computation
- self.target_tolerance = target_tolerance
-
- # Target tracking parameters
- self.max_tracking_distance_threshold = max_tracking_distance_threshold
- self.min_size_similarity = min_size_similarity
- self.direct_ee_control = direct_ee_control
-
- # Target state
- self.current_target = None
- self.target_grasp_pose = None
-
- # Detection history for robust tracking
- self.detection_history_size = 3
- self.detection_history = deque(maxlen=self.detection_history_size) # type: ignore[var-annotated]
-
- # For direct control mode visualization
- self.last_position_error = None
- self.last_target_reached = False
-
- logger.info(
- f"Initialized PBVS system with controller gains: pos={position_gain}, rot={rotation_gain}, "
- f"tracking_thresholds: distance={max_tracking_distance_threshold}m, size={min_size_similarity:.2f}"
- )
-
- def set_target(self, target_object: Detection3D) -> bool:
- """
- Set a new target object for servoing.
-
- Args:
- target_object: Detection3D object
-
- Returns:
- True if target was set successfully
- """
- if target_object and target_object.bbox and target_object.bbox.center:
- self.current_target = target_object
- self.target_grasp_pose = None # Will be computed when needed
- logger.info(f"New target set: ID {target_object.id}")
- return True
- return False
-
- def clear_target(self) -> None:
- """Clear the current target."""
- self.current_target = None
- self.target_grasp_pose = None
- self.last_position_error = None
- self.last_target_reached = False
- self.detection_history.clear()
- if self.controller:
- self.controller.clear_state()
- logger.info("Target cleared")
-
- def get_current_target(self) -> Detection3D | None:
- """
- Get the current target object.
-
- Returns:
- Current target Detection3D or None if no target selected
- """
- return self.current_target
-
- def update_tracking(self, new_detections: Detection3DArray | None = None) -> bool:
- """
- Update target tracking with new detections using a rolling window.
- If tracking is lost, keeps the old target pose.
-
- Args:
- new_detections: Optional new detections for target tracking
-
- Returns:
- True if target was successfully tracked, False if lost (but target is kept)
- """
- # Check if we have a current target
- if not self.current_target:
- return False
-
- # Add new detections to history if provided
- if new_detections is not None and new_detections.detections_length > 0:
- self.detection_history.append(new_detections)
-
- # If no detection history, can't track
- if not self.detection_history:
- logger.debug("No detection history for target tracking - using last known pose")
- return False
-
- # Collect all candidates from detection history
- all_candidates = []
- for detection_array in self.detection_history:
- all_candidates.extend(detection_array.detections)
-
- if not all_candidates:
- logger.debug("No candidates in detection history")
- return False
-
- # Use stage-dependent distance threshold
- max_distance = self.max_tracking_distance_threshold
-
- # Find best match across all recent detections
- match_result = find_best_object_match(
- target_obj=self.current_target,
- candidates=all_candidates,
- max_distance=max_distance,
- min_size_similarity=self.min_size_similarity,
- )
-
- if match_result.is_valid_match:
- self.current_target = match_result.matched_object
- self.target_grasp_pose = None # Recompute grasp pose
- logger.debug(
- f"Target tracking successful: distance={match_result.distance:.3f}m, "
- f"size_similarity={match_result.size_similarity:.2f}, "
- f"confidence={match_result.confidence:.2f}"
- )
- return True
-
- logger.debug(
- f"Target tracking lost across {len(self.detection_history)} frames: "
- f"distance={match_result.distance:.3f}m, "
- f"size_similarity={match_result.size_similarity:.2f}, "
- f"thresholds: distance={max_distance:.3f}m, size={self.min_size_similarity:.2f}"
- )
- return False
-
- def compute_control(
- self,
- ee_pose: Pose,
- grasp_distance: float = 0.15,
- grasp_pitch_degrees: float = 45.0,
- ) -> tuple[Vector3 | None, Vector3 | None, bool, bool, Pose | None]:
- """
- Compute PBVS control with position and orientation servoing.
-
- Args:
- ee_pose: Current end-effector pose
- grasp_distance: Distance to maintain from target (meters)
-
- Returns:
- Tuple of (velocity_command, angular_velocity_command, target_reached, has_target, target_pose)
- - velocity_command: Linear velocity vector or None if no target (None in direct_ee_control mode)
- - angular_velocity_command: Angular velocity vector or None if no target (None in direct_ee_control mode)
- - target_reached: True if within target tolerance
- - has_target: True if currently tracking a target
- - target_pose: Target EE pose (only in direct_ee_control mode, otherwise None)
- """
- # Check if we have a target
- if not self.current_target:
- return None, None, False, False, None
-
- # Update target grasp pose with provided distance and pitch
- self.target_grasp_pose = update_target_grasp_pose(
- self.current_target.bbox.center, ee_pose, grasp_distance, grasp_pitch_degrees
- )
-
- if self.target_grasp_pose is None:
- logger.warning("Failed to compute grasp pose")
- return None, None, False, False, None
-
- # Compute errors for visualization before checking if reached (in case pose gets cleared)
- if self.direct_ee_control and self.target_grasp_pose:
- self.last_position_error = Vector3(
- self.target_grasp_pose.position.x - ee_pose.position.x,
- self.target_grasp_pose.position.y - ee_pose.position.y,
- self.target_grasp_pose.position.z - ee_pose.position.z,
- )
-
- # Check if target reached using our separate function
- target_reached = is_target_reached(self.target_grasp_pose, ee_pose, self.target_tolerance)
-
- # Return appropriate values based on control mode
- if self.direct_ee_control:
- # Direct control mode
- if self.target_grasp_pose:
- self.last_target_reached = target_reached
- # Return has_target=True since we have a target
- return None, None, target_reached, True, self.target_grasp_pose
- else:
- return None, None, False, True, None
- else:
- # Velocity control mode - use controller
- velocity_cmd, angular_velocity_cmd, _controller_reached = (
- self.controller.compute_control(ee_pose, self.target_grasp_pose)
- )
- # Return has_target=True since we have a target, regardless of tracking status
- return velocity_cmd, angular_velocity_cmd, target_reached, True, None
-
- def create_status_overlay( # type: ignore[no-untyped-def]
- self,
- image: np.ndarray, # type: ignore[type-arg]
- grasp_stage=None,
- ) -> np.ndarray: # type: ignore[type-arg]
- """
- Create PBVS status overlay on image.
-
- Args:
- image: Input image
- grasp_stage: Current grasp stage (optional)
-
- Returns:
- Image with PBVS status overlay
- """
- stage_value = grasp_stage.value if grasp_stage else "idle"
- return create_pbvs_visualization(
- image,
- self.current_target,
- self.last_position_error,
- self.last_target_reached,
- stage_value,
- )
-
-
-class PBVSController:
- """
- Low-level Position-Based Visual Servoing controller.
- Pure control logic that computes velocity commands from poses.
-
- Handles:
- - Position and orientation error computation
- - Velocity command generation with gain control
- - Target reached detection
- """
-
- def __init__(
- self,
- position_gain: float = 0.5,
- rotation_gain: float = 0.3,
- max_velocity: float = 0.1, # m/s
- max_angular_velocity: float = 0.5, # rad/s
- target_tolerance: float = 0.01, # 1cm
- ) -> None:
- """
- Initialize PBVS controller.
-
- Args:
- position_gain: Proportional gain for position control
- rotation_gain: Proportional gain for rotation control
- max_velocity: Maximum linear velocity command magnitude (m/s)
- max_angular_velocity: Maximum angular velocity command magnitude (rad/s)
- target_tolerance: Distance threshold for considering target reached (m)
- """
- self.position_gain = position_gain
- self.rotation_gain = rotation_gain
- self.max_velocity = max_velocity
- self.max_angular_velocity = max_angular_velocity
- self.target_tolerance = target_tolerance
-
- self.last_position_error = None
- self.last_rotation_error = None
- self.last_velocity_cmd = None
- self.last_angular_velocity_cmd = None
- self.last_target_reached = False
-
- logger.info(
- f"Initialized PBVS controller: pos_gain={position_gain}, rot_gain={rotation_gain}, "
- f"max_vel={max_velocity}m/s, max_ang_vel={max_angular_velocity}rad/s, "
- f"target_tolerance={target_tolerance}m"
- )
-
- def clear_state(self) -> None:
- """Clear controller state."""
- self.last_position_error = None
- self.last_rotation_error = None
- self.last_velocity_cmd = None
- self.last_angular_velocity_cmd = None
- self.last_target_reached = False
-
- def compute_control(
- self, ee_pose: Pose, grasp_pose: Pose
- ) -> tuple[Vector3 | None, Vector3 | None, bool]:
- """
- Compute PBVS control with position and orientation servoing.
-
- Args:
- ee_pose: Current end-effector pose
- grasp_pose: Target grasp pose
-
- Returns:
- Tuple of (velocity_command, angular_velocity_command, target_reached)
- - velocity_command: Linear velocity vector
- - angular_velocity_command: Angular velocity vector
- - target_reached: True if within target tolerance
- """
- # Calculate position error (target - EE position)
- error = Vector3(
- grasp_pose.position.x - ee_pose.position.x,
- grasp_pose.position.y - ee_pose.position.y,
- grasp_pose.position.z - ee_pose.position.z,
- )
- self.last_position_error = error # type: ignore[assignment]
-
- # Compute velocity command with proportional control
- velocity_cmd = Vector3(
- error.x * self.position_gain,
- error.y * self.position_gain,
- error.z * self.position_gain,
- )
-
- # Limit velocity magnitude
- vel_magnitude = np.linalg.norm([velocity_cmd.x, velocity_cmd.y, velocity_cmd.z])
- if vel_magnitude > self.max_velocity:
- scale = self.max_velocity / vel_magnitude
- velocity_cmd = Vector3(
- float(velocity_cmd.x * scale),
- float(velocity_cmd.y * scale),
- float(velocity_cmd.z * scale),
- )
-
- self.last_velocity_cmd = velocity_cmd # type: ignore[assignment]
-
- # Compute angular velocity for orientation control
- angular_velocity_cmd = self._compute_angular_velocity(grasp_pose.orientation, ee_pose)
-
- # Check if target reached
- error_magnitude = np.linalg.norm([error.x, error.y, error.z])
- target_reached = bool(error_magnitude < self.target_tolerance)
- self.last_target_reached = target_reached
-
- return velocity_cmd, angular_velocity_cmd, target_reached
-
- def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) -> Vector3:
- """
- Compute angular velocity commands for orientation control.
- Uses quaternion error computation for better numerical stability.
-
- Args:
- target_rot: Target orientation (quaternion)
- current_pose: Current EE pose
-
- Returns:
- Angular velocity command as Vector3
- """
- # Use quaternion error for better numerical stability
-
- # Convert to scipy Rotation objects
- target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w])
- current_rot_scipy = R.from_quat(
- [
- current_pose.orientation.x,
- current_pose.orientation.y,
- current_pose.orientation.z,
- current_pose.orientation.w,
- ]
- )
-
- # Compute rotation error: error = target * current^(-1)
- error_rot = target_rot_scipy * current_rot_scipy.inv()
-
- # Convert to axis-angle representation for control
- error_axis_angle = error_rot.as_rotvec()
-
- # Use axis-angle directly as angular velocity error (small angle approximation)
- roll_error = error_axis_angle[0]
- pitch_error = error_axis_angle[1]
- yaw_error = error_axis_angle[2]
-
- self.last_rotation_error = Vector3(roll_error, pitch_error, yaw_error) # type: ignore[assignment]
-
- # Apply proportional control
- angular_velocity = Vector3(
- roll_error * self.rotation_gain,
- pitch_error * self.rotation_gain,
- yaw_error * self.rotation_gain,
- )
-
- # Limit angular velocity magnitude
- ang_vel_magnitude = np.sqrt(
- angular_velocity.x**2 + angular_velocity.y**2 + angular_velocity.z**2
- )
- if ang_vel_magnitude > self.max_angular_velocity:
- scale = self.max_angular_velocity / ang_vel_magnitude
- angular_velocity = Vector3(
- angular_velocity.x * scale, angular_velocity.y * scale, angular_velocity.z * scale
- )
-
- self.last_angular_velocity_cmd = angular_velocity # type: ignore[assignment]
-
- return angular_velocity
-
- def create_status_overlay(
- self,
- image: np.ndarray, # type: ignore[type-arg]
- current_target: Detection3D | None = None,
- ) -> np.ndarray: # type: ignore[type-arg]
- """
- Create PBVS status overlay on image.
-
- Args:
- image: Input image
- current_target: Current target object Detection3D (for display)
-
- Returns:
- Image with PBVS status overlay
- """
- return create_pbvs_visualization(
- image,
- current_target,
- self.last_position_error,
- self.last_target_reached,
- "velocity_control",
- )
diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py
deleted file mode 100644
index 5922739429..0000000000
--- a/dimos/manipulation/visual_servoing/utils.py
+++ /dev/null
@@ -1,801 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Any
-
-import cv2
-from dimos_lcm.vision_msgs import Detection2D, Detection3D
-import numpy as np
-
-from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3
-from dimos.perception.common.utils import project_2d_points_to_3d
-from dimos.perception.detection2d.utils import plot_results
-from dimos.utils.transform_utils import (
- compose_transforms,
- euler_to_quaternion,
- get_distance,
- matrix_to_pose,
- offset_distance,
- optical_to_robot_frame,
- pose_to_matrix,
- robot_to_optical_frame,
- yaw_towards_point,
-)
-
-
-def match_detection_by_id(
- detection_3d: Detection3D, detections_3d: list[Detection3D], detections_2d: list[Detection2D]
-) -> Detection2D | None:
- """
- Find the corresponding Detection2D for a given Detection3D.
-
- Args:
- detection_3d: The Detection3D to match
- detections_3d: List of all Detection3D objects
- detections_2d: List of all Detection2D objects (must be 1:1 correspondence)
-
- Returns:
- Corresponding Detection2D if found, None otherwise
- """
- for i, det_3d in enumerate(detections_3d):
- if det_3d.id == detection_3d.id and i < len(detections_2d):
- return detections_2d[i]
- return None
-
-
-def transform_pose(
- obj_pos: np.ndarray, # type: ignore[type-arg]
- obj_orientation: np.ndarray, # type: ignore[type-arg]
- transform_matrix: np.ndarray, # type: ignore[type-arg]
- to_optical: bool = False,
- to_robot: bool = False,
-) -> Pose:
- """
- Transform object pose with optional frame convention conversion.
-
- Args:
- obj_pos: Object position [x, y, z]
- obj_orientation: Object orientation [roll, pitch, yaw] in radians
- transform_matrix: 4x4 transformation matrix from camera frame to desired frame
- to_optical: If True, input is in robot frame → convert result to optical frame
- to_robot: If True, input is in optical frame → convert to robot frame first
-
- Returns:
- Object pose in desired frame as Pose
- """
- # Convert euler angles to quaternion using utility function
- euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2])
- obj_orientation_quat = euler_to_quaternion(euler_vector)
-
- input_pose = Pose(
- position=Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), orientation=obj_orientation_quat
- )
-
- # Apply input frame conversion based on flags
- if to_robot:
- # Input is in optical frame → convert to robot frame first
- pose_for_transform = optical_to_robot_frame(input_pose)
- else:
- # Default or to_optical: use input pose as-is
- pose_for_transform = input_pose
-
- # Create transformation matrix from pose (relative to camera)
- T_camera_object = pose_to_matrix(pose_for_transform)
-
- # Use compose_transforms to combine transformations
- T_desired_object = compose_transforms(transform_matrix, T_camera_object)
-
- # Convert back to pose
- result_pose = matrix_to_pose(T_desired_object)
-
- # Apply output frame conversion based on flags
- if to_optical:
- # Input was robot frame → convert result to optical frame
- desired_pose = robot_to_optical_frame(result_pose)
- else:
- # Default or to_robot: use result as-is
- desired_pose = result_pose
-
- return desired_pose
-
-
-def transform_points_3d(
- points_3d: np.ndarray, # type: ignore[type-arg]
- transform_matrix: np.ndarray, # type: ignore[type-arg]
- to_optical: bool = False,
- to_robot: bool = False,
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Transform 3D points with optional frame convention conversion.
- Applies the same transformation pipeline as transform_pose but for multiple points.
-
- Args:
- points_3d: Nx3 array of 3D points [x, y, z]
- transform_matrix: 4x4 transformation matrix from camera frame to desired frame
- to_optical: If True, input is in robot frame → convert result to optical frame
- to_robot: If True, input is in optical frame → convert to robot frame first
-
- Returns:
- Nx3 array of transformed 3D points in desired frame
- """
- if points_3d.size == 0:
- return np.zeros((0, 3), dtype=np.float32)
-
- points_3d = np.asarray(points_3d)
- if points_3d.ndim == 1:
- points_3d = points_3d.reshape(1, -1)
-
- transformed_points = []
-
- for point in points_3d:
- input_point_pose = Pose(
- position=Vector3(point[0], point[1], point[2]),
- orientation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity quaternion
- )
-
- # Apply input frame conversion based on flags
- if to_robot:
- # Input is in optical frame → convert to robot frame first
- pose_for_transform = optical_to_robot_frame(input_point_pose)
- else:
- # Default or to_optical: use input pose as-is
- pose_for_transform = input_point_pose
-
- # Create transformation matrix from point pose (relative to camera)
- T_camera_point = pose_to_matrix(pose_for_transform)
-
- # Use compose_transforms to combine transformations
- T_desired_point = compose_transforms(transform_matrix, T_camera_point)
-
- # Convert back to pose
- result_pose = matrix_to_pose(T_desired_point)
-
- # Apply output frame conversion based on flags
- if to_optical:
- # Input was robot frame → convert result to optical frame
- desired_pose = robot_to_optical_frame(result_pose)
- else:
- # Default or to_robot: use result as-is
- desired_pose = result_pose
-
- transformed_point = [
- desired_pose.position.x,
- desired_pose.position.y,
- desired_pose.position.z,
- ]
- transformed_points.append(transformed_point)
-
- return np.array(transformed_points, dtype=np.float32)
-
-
-def select_points_from_depth(
- depth_image: np.ndarray, # type: ignore[type-arg]
- target_point: tuple[int, int],
- camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg]
- radius: int = 5,
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Select points around a target point within a bounding box and project them to 3D.
-
- Args:
- depth_image: Depth image in meters (H, W)
- target_point: (x, y) target point coordinates
- radius: Half-width of the bounding box (so bbox size is radius*2 x radius*2)
- camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix
-
- Returns:
- Nx3 array of 3D points (X, Y, Z) in camera frame
- """
- x_target, y_target = target_point
- height, width = depth_image.shape
-
- x_min = max(0, x_target - radius)
- x_max = min(width, x_target + radius)
- y_min = max(0, y_target - radius)
- y_max = min(height, y_target + radius)
-
- # Create coordinate grids for the bounding box (vectorized)
- y_coords, x_coords = np.meshgrid(range(y_min, y_max), range(x_min, x_max), indexing="ij")
-
- # Flatten to get all coordinate pairs
- x_flat = x_coords.flatten()
- y_flat = y_coords.flatten()
-
- # Extract corresponding depth values using advanced indexing
- depth_flat = depth_image[y_flat, x_flat]
-
- valid_mask = (depth_flat > 0) & np.isfinite(depth_flat)
-
- if not np.any(valid_mask):
- return np.zeros((0, 3), dtype=np.float32)
-
- points_2d = np.column_stack([x_flat[valid_mask], y_flat[valid_mask]]).astype(np.float32)
- depth_values = depth_flat[valid_mask].astype(np.float32)
-
- points_3d = project_2d_points_to_3d(points_2d, depth_values, camera_intrinsics)
-
- return points_3d
-
-
-def update_target_grasp_pose(
- target_pose: Pose, ee_pose: Pose, grasp_distance: float = 0.0, grasp_pitch_degrees: float = 45.0
-) -> Pose | None:
- """
- Update target grasp pose based on current target pose and EE pose.
-
- Args:
- target_pose: Target pose to grasp
- ee_pose: Current end-effector pose
- grasp_distance: Distance to maintain from target (pregrasp or grasp distance)
- grasp_pitch_degrees: Grasp pitch angle in degrees (default 90° for top-down)
-
- Returns:
- Target grasp pose or None if target is invalid
- """
-
- target_pos = target_pose.position
-
- # Calculate orientation pointing from target towards EE
- yaw_to_ee = yaw_towards_point(target_pos, ee_pose.position)
-
- # Create target pose with proper orientation
- # Convert grasp pitch from degrees to radians with mapping:
- # 0° (level) -> π/2 (1.57 rad), 90° (top-down) -> π (3.14 rad)
- pitch_radians = 1.57 + np.radians(grasp_pitch_degrees)
-
- # Convert euler angles to quaternion using utility function
- euler = Vector3(0.0, pitch_radians, yaw_to_ee) # roll=0, pitch=mapped, yaw=calculated
- target_orientation = euler_to_quaternion(euler)
-
- updated_pose = Pose(target_pos, target_orientation)
-
- if grasp_distance > 0.0:
- return offset_distance(updated_pose, grasp_distance)
- else:
- return updated_pose
-
-
-def is_target_reached(target_pose: Pose, current_pose: Pose, tolerance: float = 0.01) -> bool:
- """
- Check if the target pose has been reached within tolerance.
-
- Args:
- target_pose: Target pose to reach
- current_pose: Current pose (e.g., end-effector pose)
- tolerance: Distance threshold for considering target reached (meters, default 0.01 = 1cm)
-
- Returns:
- True if target is reached within tolerance, False otherwise
- """
- # Calculate position error using distance utility
- error_magnitude = get_distance(target_pose, current_pose)
- return error_magnitude < tolerance
-
-
-@dataclass
-class ObjectMatchResult:
- """Result of object matching with confidence metrics."""
-
- matched_object: Detection3D | None
- confidence: float
- distance: float
- size_similarity: float
- is_valid_match: bool
-
-
-def calculate_object_similarity(
- target_obj: Detection3D,
- candidate_obj: Detection3D,
- distance_weight: float = 0.6,
- size_weight: float = 0.4,
-) -> tuple[float, float, float]:
- """
- Calculate comprehensive similarity between two objects.
-
- Args:
- target_obj: Target Detection3D object
- candidate_obj: Candidate Detection3D object
- distance_weight: Weight for distance component (0-1)
- size_weight: Weight for size component (0-1)
-
- Returns:
- Tuple of (total_similarity, distance_m, size_similarity)
- """
- # Extract positions
- target_pos = target_obj.bbox.center.position
- candidate_pos = candidate_obj.bbox.center.position
-
- target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z])
- candidate_xyz = np.array([candidate_pos.x, candidate_pos.y, candidate_pos.z])
-
- # Calculate Euclidean distance
- distance = np.linalg.norm(target_xyz - candidate_xyz)
- distance_similarity = 1.0 / (1.0 + distance) # Exponential decay
-
- # Calculate size similarity by comparing each dimension individually
- size_similarity = 1.0 # Default if no size info
- target_size = target_obj.bbox.size
- candidate_size = candidate_obj.bbox.size
-
- if target_size and candidate_size:
- # Extract dimensions
- target_dims = [target_size.x, target_size.y, target_size.z]
- candidate_dims = [candidate_size.x, candidate_size.y, candidate_size.z]
-
- # Calculate similarity for each dimension pair
- dim_similarities = []
- for target_dim, candidate_dim in zip(target_dims, candidate_dims, strict=False):
- if target_dim == 0.0 and candidate_dim == 0.0:
- dim_similarities.append(1.0) # Both dimensions are zero
- elif target_dim == 0.0 or candidate_dim == 0.0:
- dim_similarities.append(0.0) # One dimension is zero, other is not
- else:
- # Calculate similarity as min/max ratio
- max_dim = max(target_dim, candidate_dim)
- min_dim = min(target_dim, candidate_dim)
- dim_similarity = min_dim / max_dim if max_dim > 0 else 0.0
- dim_similarities.append(dim_similarity)
-
- # Return average similarity across all dimensions
- size_similarity = np.mean(dim_similarities) if dim_similarities else 0.0 # type: ignore[assignment]
-
- # Weighted combination
- total_similarity = distance_weight * distance_similarity + size_weight * size_similarity
-
- return total_similarity, distance, size_similarity # type: ignore[return-value]
-
-
-def find_best_object_match(
- target_obj: Detection3D,
- candidates: list[Detection3D],
- max_distance: float = 0.1,
- min_size_similarity: float = 0.4,
- distance_weight: float = 0.7,
- size_weight: float = 0.3,
-) -> ObjectMatchResult:
- """
- Find the best matching object from candidates using distance and size criteria.
-
- Args:
- target_obj: Target Detection3D to match against
- candidates: List of candidate Detection3D objects
- max_distance: Maximum allowed distance for valid match (meters)
- min_size_similarity: Minimum size similarity for valid match (0-1)
- distance_weight: Weight for distance in similarity calculation
- size_weight: Weight for size in similarity calculation
-
- Returns:
- ObjectMatchResult with best match and confidence metrics
- """
- if not candidates or not target_obj.bbox or not target_obj.bbox.center:
- return ObjectMatchResult(None, 0.0, float("inf"), 0.0, False)
-
- best_match = None
- best_confidence = 0.0
- best_distance = float("inf")
- best_size_sim = 0.0
-
- for candidate in candidates:
- if not candidate.bbox or not candidate.bbox.center:
- continue
-
- similarity, distance, size_sim = calculate_object_similarity(
- target_obj, candidate, distance_weight, size_weight
- )
-
- # Check validity constraints
- is_valid = distance <= max_distance and size_sim >= min_size_similarity
-
- if is_valid and similarity > best_confidence:
- best_match = candidate
- best_confidence = similarity
- best_distance = distance
- best_size_sim = size_sim
-
- return ObjectMatchResult(
- matched_object=best_match,
- confidence=best_confidence,
- distance=best_distance,
- size_similarity=best_size_sim,
- is_valid_match=best_match is not None,
- )
-
-
-def parse_zed_pose(zed_pose_data: dict[str, Any]) -> Pose | None:
- """
- Parse ZED pose data dictionary into a Pose object.
-
- Args:
- zed_pose_data: Dictionary from ZEDCamera.get_pose() containing:
- - position: [x, y, z] in meters
- - rotation: [x, y, z, w] quaternion
- - euler_angles: [roll, pitch, yaw] in radians
- - valid: Whether pose is valid
-
- Returns:
- Pose object with position and orientation, or None if invalid
- """
- if not zed_pose_data or not zed_pose_data.get("valid", False):
- return None
-
- # Extract position
- position = zed_pose_data.get("position", [0, 0, 0])
- pos_vector = Vector3(position[0], position[1], position[2])
-
- quat = zed_pose_data["rotation"]
- orientation = Quaternion(quat[0], quat[1], quat[2], quat[3])
- return Pose(position=pos_vector, orientation=orientation)
-
-
-def estimate_object_depth(
- depth_image: np.ndarray, # type: ignore[type-arg]
- segmentation_mask: np.ndarray | None, # type: ignore[type-arg]
- bbox: list[float],
-) -> float:
- """
- Estimate object depth dimension using segmentation mask and depth data.
- Optimized for real-time performance.
-
- Args:
- depth_image: Depth image in meters
- segmentation_mask: Binary segmentation mask for the object
- bbox: Bounding box [x1, y1, x2, y2]
-
- Returns:
- Estimated object depth in meters
- """
- x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
-
- # Extract depth ROI once
- roi_depth = depth_image[y1:y2, x1:x2]
-
- if segmentation_mask is not None and segmentation_mask.size > 0:
- # Extract mask ROI efficiently
- mask_roi = (
- segmentation_mask[y1:y2, x1:x2]
- if segmentation_mask.shape != roi_depth.shape
- else segmentation_mask
- )
-
- # Fast mask application using boolean indexing
- valid_mask = mask_roi > 0
- if np.sum(valid_mask) > 10: # Early exit if not enough points
- masked_depths = roi_depth[valid_mask]
-
- # Fast percentile calculation using numpy's optimized functions
- depth_90 = np.percentile(masked_depths, 90)
- depth_10 = np.percentile(masked_depths, 10)
- depth_range = depth_90 - depth_10
-
- # Clamp to reasonable bounds with single operation
- return np.clip(depth_range, 0.02, 0.5) # type: ignore[no-any-return]
-
- # Fast fallback using area calculation
- bbox_area = (x2 - x1) * (y2 - y1)
-
- # Vectorized area-based estimation
- if bbox_area > 10000:
- return 0.15
- elif bbox_area > 5000:
- return 0.10
- else:
- return 0.05
-
-
-# ============= Visualization Functions =============
-
-
-def create_manipulation_visualization( # type: ignore[no-untyped-def]
- rgb_image: np.ndarray, # type: ignore[type-arg]
- feedback,
- detection_3d_array=None,
- detection_2d_array=None,
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Create simple visualization for manipulation class using feedback.
-
- Args:
- rgb_image: RGB image array
- feedback: Feedback object containing all state information
- detection_3d_array: Optional 3D detections for object visualization
- detection_2d_array: Optional 2D detections for object visualization
-
- Returns:
- BGR image with visualization overlays
- """
- # Convert to BGR for OpenCV
- viz = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
-
- # Draw detections if available
- if detection_3d_array and detection_2d_array:
- # Extract 2D bboxes
- bboxes_2d = []
- for det_2d in detection_2d_array.detections:
- if det_2d.bbox:
- x1 = det_2d.bbox.center.position.x - det_2d.bbox.size_x / 2
- y1 = det_2d.bbox.center.position.y - det_2d.bbox.size_y / 2
- x2 = det_2d.bbox.center.position.x + det_2d.bbox.size_x / 2
- y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2
- bboxes_2d.append([x1, y1, x2, y2])
-
- # Draw basic detections
- rgb_with_detections = visualize_detections_3d(
- rgb_image, detection_3d_array.detections, show_coordinates=True, bboxes_2d=bboxes_2d
- )
- viz = cv2.cvtColor(rgb_with_detections, cv2.COLOR_RGB2BGR)
-
- # Add manipulation status overlay
- status_y = 30
- cv2.putText(
- viz,
- "Eye-in-Hand Visual Servoing",
- (10, status_y),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.6,
- (0, 255, 255),
- 2,
- )
-
- # Stage information
- stage_text = f"Stage: {feedback.grasp_stage.value.upper()}"
- stage_color = {
- "idle": (100, 100, 100),
- "pre_grasp": (0, 255, 255),
- "grasp": (0, 255, 0),
- "close_and_retract": (255, 0, 255),
- "place": (0, 150, 255),
- "retract": (255, 150, 0),
- }.get(feedback.grasp_stage.value, (255, 255, 255))
-
- cv2.putText(
- viz,
- stage_text,
- (10, status_y + 25),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- stage_color,
- 1,
- )
-
- # Target tracking status
- if feedback.target_tracked:
- cv2.putText(
- viz,
- "Target: TRACKED",
- (10, status_y + 45),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- (0, 255, 0),
- 1,
- )
- elif feedback.grasp_stage.value != "idle":
- cv2.putText(
- viz,
- "Target: LOST",
- (10, status_y + 45),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- (0, 0, 255),
- 1,
- )
-
- # Waiting status
- if feedback.waiting_for_reach:
- cv2.putText(
- viz,
- "Status: WAITING FOR ROBOT",
- (10, status_y + 65),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- (255, 255, 0),
- 1,
- )
-
- # Overall result
- if feedback.success is not None:
- result_text = "Pick & Place: SUCCESS" if feedback.success else "Pick & Place: FAILED"
- result_color = (0, 255, 0) if feedback.success else (0, 0, 255)
- cv2.putText(
- viz,
- result_text,
- (10, status_y + 85),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- result_color,
- 2,
- )
-
- # Control hints (bottom of image)
- hint_text = "Click object to grasp | s=STOP | r=RESET | g=RELEASE"
- cv2.putText(
- viz,
- hint_text,
- (10, viz.shape[0] - 10),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.4,
- (200, 200, 200),
- 1,
- )
-
- return viz
-
-
-def create_pbvs_visualization( # type: ignore[no-untyped-def]
- image: np.ndarray, # type: ignore[type-arg]
- current_target=None,
- position_error=None,
- target_reached: bool = False,
- grasp_stage: str = "idle",
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Create simple PBVS visualization overlay.
-
- Args:
- image: Input image (RGB or BGR)
- current_target: Current target Detection3D
- position_error: Position error Vector3
- target_reached: Whether target is reached
- grasp_stage: Current grasp stage string
-
- Returns:
- Image with PBVS overlay
- """
- viz = image.copy()
-
- # Only show PBVS info if we have a target
- if current_target is None:
- return viz
-
- # Create status panel at bottom
- height, width = viz.shape[:2]
- panel_height = 100
- panel_y = height - panel_height
-
- # Semi-transparent overlay
- overlay = viz.copy()
- cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1)
- viz = cv2.addWeighted(viz, 0.7, overlay, 0.3, 0)
-
- # PBVS Status
- y_offset = panel_y + 20
- cv2.putText(
- viz,
- "PBVS Control",
- (10, y_offset),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.6,
- (0, 255, 255),
- 2,
- )
-
- # Position error
- if position_error:
- error_mag = np.linalg.norm([position_error.x, position_error.y, position_error.z])
- error_text = f"Error: {error_mag * 100:.1f}cm"
- error_color = (0, 255, 0) if target_reached else (0, 255, 255)
- cv2.putText(
- viz,
- error_text,
- (10, y_offset + 25),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- error_color,
- 1,
- )
-
- # Stage
- cv2.putText(
- viz,
- f"Stage: {grasp_stage}",
- (10, y_offset + 45),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- (255, 150, 255),
- 1,
- )
-
- # Target reached indicator
- if target_reached:
- cv2.putText(
- viz,
- "TARGET REACHED",
- (width - 150, y_offset + 25),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.6,
- (0, 255, 0),
- 2,
- )
-
- return viz
-
-
-def visualize_detections_3d(
- rgb_image: np.ndarray, # type: ignore[type-arg]
- detections: list[Detection3D],
- show_coordinates: bool = True,
- bboxes_2d: list[list[float]] | None = None,
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Visualize detections with 3D position overlay next to bounding boxes.
-
- Args:
- rgb_image: Original RGB image
- detections: List of Detection3D objects
- show_coordinates: Whether to show 3D coordinates next to bounding boxes
- bboxes_2d: Optional list of 2D bounding boxes corresponding to detections
-
- Returns:
- Visualization image
- """
- if not detections:
- return rgb_image.copy()
-
- # If no 2D bboxes provided, skip visualization
- if bboxes_2d is None:
- return rgb_image.copy()
-
- # Extract data for plot_results function
- bboxes = bboxes_2d
- track_ids = [int(det.id) if det.id.isdigit() else i for i, det in enumerate(detections)]
- class_ids = [i for i in range(len(detections))]
- confidences = [
- det.results[0].hypothesis.score if det.results_length > 0 else 0.0 for det in detections
- ]
- names = [
- det.results[0].hypothesis.class_id if det.results_length > 0 else "unknown"
- for det in detections
- ]
-
- # Use plot_results for basic visualization
- viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names)
-
- # Add 3D position coordinates if requested
- if show_coordinates and bboxes_2d is not None:
- for i, det in enumerate(detections):
- if det.bbox and det.bbox.center and i < len(bboxes_2d):
- position = det.bbox.center.position
- bbox = bboxes_2d[i]
-
- pos_xyz = np.array([position.x, position.y, position.z])
-
- # Get bounding box coordinates
- _x1, y1, x2, _y2 = map(int, bbox)
-
- # Add position text next to bounding box (top-right corner)
- pos_text = f"({pos_xyz[0]:.2f}, {pos_xyz[1]:.2f}, {pos_xyz[2]:.2f})"
- text_x = x2 + 5 # Right edge of bbox + small offset
- text_y = y1 + 15 # Top edge of bbox + small offset
-
- # Add background rectangle for better readability
- text_size = cv2.getTextSize(pos_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0]
- cv2.rectangle(
- viz,
- (text_x - 2, text_y - text_size[1] - 2),
- (text_x + text_size[0] + 2, text_y + 2),
- (0, 0, 0),
- -1,
- )
-
- cv2.putText(
- viz,
- pos_text,
- (text_x, text_y),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.4,
- (255, 255, 255),
- 1,
- )
-
- return viz # type: ignore[no-any-return]
diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py
index ee7512baba..97c96f8180 100644
--- a/dimos/mapping/costmapper.py
+++ b/dimos/mapping/costmapper.py
@@ -13,8 +13,6 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
-import queue
-import threading
import time
from reactivex import operators as ops
@@ -50,14 +48,14 @@ class CostMapper(Module):
global_map: In[PointCloud2]
global_costmap: Out[OccupancyGrid]
- # Background Rerun logging (decouples viz from data pipeline)
- _rerun_queue: queue.Queue[tuple[OccupancyGrid, float, float] | None]
- _rerun_thread: threading.Thread | None = None
-
@classmethod
def rerun_views(cls): # type: ignore[no-untyped-def]
"""Return Rerun view blueprints for costmap visualization."""
return [
+ rrb.Spatial2DView(
+ name="Costmap",
+ origin="world/nav/costmap/image",
+ ),
rrb.TimeSeriesView(
name="Costmap (ms)",
origin="/metrics/costmap",
@@ -68,62 +66,47 @@ def rerun_views(cls): # type: ignore[no-untyped-def]
def __init__(self, global_config: GlobalConfig | None = None, **kwargs: object) -> None:
super().__init__(**kwargs)
self._global_config = global_config or GlobalConfig()
- self._rerun_queue = queue.Queue(maxsize=2)
- def _rerun_worker(self) -> None:
- """Background thread: pull from queue and log to Rerun (non-blocking)."""
- while True:
- try:
- item = self._rerun_queue.get(timeout=1.0)
- if item is None: # Shutdown signal
- break
+ @rpc
+ def start(self) -> None:
+ super().start()
+
+ # Only start Rerun logging if Rerun backend is selected
+ if self._global_config.viewer_backend.startswith("rerun"):
+ connect_rerun(global_config=self._global_config)
+ logger.info("CostMapper: Rerun logging enabled (sync)")
- grid, calc_time_ms, rx_monotonic = item
+ def _publish_costmap(grid: OccupancyGrid, calc_time_ms: float, rx_monotonic: float) -> None:
+ # Publish to downstream first.
+ self.global_costmap.publish(grid)
- # Generate mesh + log to Rerun (blocks in background, not on data path)
+ # Synchronous Rerun logging (no queues/threads).
+ if self._global_config.viewer_backend.startswith("rerun"):
try:
- # 3D floor overlay (expensive mesh generation)
+ # 2D image panel
+ rr.log(
+ "world/nav/costmap/image",
+ grid.to_rerun(
+ mode="image",
+ colormap="RdBu_r",
+ ),
+ )
+
+ # 3D floor overlay (mesh)
rr.log(
"world/nav/costmap/floor",
grid.to_rerun(
mode="mesh",
- colormap=None, # Uses Foxglove-style colors (blue-purple free, black occupied)
- z_offset=0.05, # 5cm above floor to avoid z-fighting
+ colormap=None, # grayscale / foxglove-style
+ z_offset=0.07,
),
)
- # Log timing metrics
rr.log("metrics/costmap/calc_ms", rr.Scalars(calc_time_ms))
latency_ms = (time.monotonic() - rx_monotonic) * 1000
rr.log("metrics/costmap/latency_ms", rr.Scalars(latency_ms))
except Exception as e:
logger.warning(f"Rerun logging error: {e}")
- except queue.Empty:
- continue
-
- @rpc
- def start(self) -> None:
- super().start()
-
- # Only start Rerun logging if Rerun backend is selected
- if self._global_config.viewer_backend.startswith("rerun"):
- connect_rerun(global_config=self._global_config)
-
- # Start background Rerun logging thread
- self._rerun_thread = threading.Thread(target=self._rerun_worker, daemon=True)
- self._rerun_thread.start()
- logger.info("CostMapper: started async Rerun logging thread")
-
- def _publish_costmap(grid: OccupancyGrid, calc_time_ms: float, rx_monotonic: float) -> None:
- # Publish to downstream FIRST (fast, not blocked by Rerun)
- self.global_costmap.publish(grid)
-
- # Queue for async Rerun logging (non-blocking, drops if queue full)
- if self._rerun_thread and self._rerun_thread.is_alive():
- try:
- self._rerun_queue.put_nowait((grid, calc_time_ms, rx_monotonic))
- except queue.Full:
- pass # Drop viz frame, data pipeline continues
def _calculate_and_time(
msg: PointCloud2,
@@ -142,11 +125,6 @@ def _calculate_and_time(
@rpc
def stop(self) -> None:
- # Shutdown background Rerun thread
- if self._rerun_thread and self._rerun_thread.is_alive():
- self._rerun_queue.put(None) # Shutdown signal
- self._rerun_thread.join(timeout=2.0)
-
super().stop()
# @timed() # TODO: fix thread leak in timed decorator
diff --git a/dimos/mapping/occupancy/test_extrude_occupancy.py b/dimos/mapping/occupancy/test_extrude_occupancy.py
index 81caba7c8d..88f05d7780 100644
--- a/dimos/mapping/occupancy/test_extrude_occupancy.py
+++ b/dimos/mapping/occupancy/test_extrude_occupancy.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
+
from dimos.mapping.occupancy.extrude_occupancy import generate_mujoco_scene
from dimos.utils.data import get_data
+@pytest.mark.integration
def test_generate_mujoco_scene(occupancy) -> None:
with open(get_data("expected_occupancy_scene.xml")) as f:
expected = f.read()
diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py
index a36dc9bc17..6570d9ba33 100644
--- a/dimos/mapping/voxels.py
+++ b/dimos/mapping/voxels.py
@@ -13,8 +13,6 @@
# limitations under the License.
from dataclasses import dataclass
-import queue
-import threading
import time
import numpy as np
@@ -31,7 +29,6 @@
from dimos.core.module import ModuleConfig
from dimos.dashboard.rerun_init import connect_rerun
from dimos.msgs.sensor_msgs import PointCloud2
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
from dimos.utils.decorators import simple_mcache
from dimos.utils.logging_config import setup_logger
from dimos.utils.reactive import backpressure
@@ -54,7 +51,7 @@ class VoxelGridMapper(Module):
default_config = Config
config: Config
- lidar: In[LidarMessage]
+ lidar: In[PointCloud2]
global_map: Out[PointCloud2]
@classmethod
@@ -106,33 +103,6 @@ def __init__(self, global_config: GlobalConfig | None = None, **kwargs: object)
# Monotonic timestamp of last received frame (for accurate latency in replay)
self._latest_frame_rx_monotonic: float | None = None
- # Background Rerun logging (decouples viz from data pipeline)
- self._rerun_queue: queue.Queue[PointCloud2 | None] = queue.Queue(maxsize=2)
- self._rerun_thread: threading.Thread | None = None
-
- def _rerun_worker(self) -> None:
- """Background thread: pull from queue and log to Rerun (non-blocking)."""
- while True:
- try:
- pc = self._rerun_queue.get(timeout=1.0)
- if pc is None: # Shutdown signal
- break
-
- # Log to Rerun (blocks in background, doesn't affect data pipeline)
- try:
- rr.log(
- "world/map",
- pc.to_rerun(
- mode="boxes",
- size=self.config.voxel_size,
- colormap="turbo",
- ),
- )
- except Exception as e:
- logger.warning(f"Rerun logging error: {e}")
- except queue.Empty:
- continue
-
@rpc
def start(self) -> None:
super().start()
@@ -140,11 +110,7 @@ def start(self) -> None:
# Only start Rerun logging if Rerun backend is selected
if self._global_config.viewer_backend.startswith("rerun"):
connect_rerun(global_config=self._global_config)
-
- # Start background Rerun logging thread (decouples viz from data pipeline)
- self._rerun_thread = threading.Thread(target=self._rerun_worker, daemon=True)
- self._rerun_thread.start()
- logger.info("VoxelGridMapper: started async Rerun logging thread")
+ logger.info("VoxelGridMapper: Rerun logging enabled (sync)")
# Subject to trigger publishing, with backpressure to drop if busy
self._publish_trigger: Subject[None] = Subject()
@@ -167,14 +133,9 @@ def start(self) -> None:
@rpc
def stop(self) -> None:
- # Shutdown background Rerun thread
- if self._rerun_thread and self._rerun_thread.is_alive():
- self._rerun_queue.put(None) # Shutdown signal
- self._rerun_thread.join(timeout=2.0)
-
super().stop()
- def _on_frame(self, frame: LidarMessage) -> None:
+ def _on_frame(self, frame: PointCloud2) -> None:
# Track receipt time with monotonic clock (works correctly in replay)
self._latest_frame_rx_monotonic = time.monotonic()
self.add_frame(frame)
@@ -196,12 +157,19 @@ def publish_global_map(self) -> None:
t2 = time.perf_counter()
self.global_map.publish(pc)
publish_ms = (time.perf_counter() - t2) * 1000
-
- # 3. Queue for async Rerun logging (non-blocking, drops if queue full)
- try:
- self._rerun_queue.put_nowait(pc)
- except queue.Full:
- pass # Drop viz frame, data pipeline continues
+ # 3. Synchronous Rerun logging (no queues/threads).
+ if self._global_config.viewer_backend.startswith("rerun"):
+ try:
+ rr.log(
+ "world/map",
+ pc.to_rerun(
+ mode="boxes",
+ size=self.config.voxel_size,
+ colormap="turbo",
+ ),
+ )
+ except Exception as e:
+ logger.warning(f"Rerun logging error: {e}")
# Log detailed timing breakdown to Rerun
total_ms = (time.perf_counter() - start_total) * 1000
diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py
index 0769fc150d..76bb377869 100644
--- a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py
+++ b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py
@@ -3,7 +3,6 @@
import os
from contact_graspnet_pytorch import config_utils # type: ignore[import-not-found]
-from contact_graspnet_pytorch.checkpoints import CheckpointIO # type: ignore[import-not-found]
from contact_graspnet_pytorch.contact_grasp_estimator import ( # type: ignore[import-not-found]
GraspEstimator,
)
@@ -11,6 +10,7 @@
load_available_input_data,
)
import numpy as np
+import torch
from dimos.utils.data import get_data
@@ -45,12 +45,9 @@ def inference(global_config, # type: ignore[no-untyped-def]
# Load the weights
model_checkpoint_dir = get_data(ckpt_dir)
- checkpoint_io = CheckpointIO(checkpoint_dir=model_checkpoint_dir, model=grasp_estimator.model)
- try:
- checkpoint_io.load('model.pt')
- except FileExistsError:
- print('No model checkpoint found')
-
+ checkpoint_path = os.path.join(model_checkpoint_dir, 'model.pt')
+ state_dict = torch.load(checkpoint_path, weights_only=False)
+ grasp_estimator.model.load_state_dict(state_dict['model'])
os.makedirs('results', exist_ok=True)
diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py
index 7964a24954..7ee0f49451 100644
--- a/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py
+++ b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py
@@ -13,6 +13,7 @@ def is_manipulation_installed() -> bool:
except ImportError:
return False
+@pytest.mark.integration
@pytest.mark.skipif(not is_manipulation_installed(),
reason="This test requires 'pip install .[manipulation]' to be run")
def test_contact_graspnet_inference() -> None:
diff --git a/dimos/models/segmentation/configs/edgetam.yaml b/dimos/models/segmentation/configs/edgetam.yaml
new file mode 100644
index 0000000000..6fe21c99df
--- /dev/null
+++ b/dimos/models/segmentation/configs/edgetam.yaml
@@ -0,0 +1,138 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.sam2_video_predictor.SAM2VideoPredictor
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.timm.TimmBackbone
+ name: repvit_m1.dist_in1k
+ features:
+ - layer0
+ - layer1
+ - layer2
+ - layer3
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [384, 192, 96, 48]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttentionv2
+ rope_theta: 10000.0
+ q_sizes: [64, 64]
+ k_sizes: [16, 16]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 2
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ spatial_perceiver:
+ _target_: sam2.modeling.perceiver.PerceiverResampler
+ depth: 2
+ dim: 64
+ dim_head: 64
+ heads: 1
+ ff_mult: 4
+ hidden_dropout_p: 0.
+ attention_dropout_p: 0.
+ pos_enc_at_key_value: true # implicit pos
+ concat_kv_latents: false
+ num_latents: 256
+ num_latents_2d: 256
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ use_self_attn: true
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: false
diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py
new file mode 100644
index 0000000000..ba351be130
--- /dev/null
+++ b/dimos/models/segmentation/edge_tam.py
@@ -0,0 +1,269 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Generator
+from contextlib import contextmanager
+import os
+from pathlib import Path
+import shutil
+import tempfile
+from typing import TYPE_CHECKING, Any, TypedDict
+
+import cv2
+from hydra.utils import instantiate # type: ignore[import-not-found]
+import numpy as np
+from numpy.typing import NDArray
+from omegaconf import OmegaConf # type: ignore[import-not-found]
+from PIL import Image as PILImage
+import torch
+
+from dimos.msgs.sensor_msgs import Image
+from dimos.perception.detection.detectors.types import Detector
+from dimos.perception.detection.type import ImageDetections2D
+from dimos.perception.detection.type.detection2d.seg import Detection2DSeg
+from dimos.utils.data import get_data
+from dimos.utils.logging_config import setup_logger
+
+if TYPE_CHECKING:
+ from sam2.sam2_video_predictor import SAM2VideoPredictor # type: ignore[import-untyped]
+
+os.environ['TQDM_DISABLE'] = '1'
+
+logger = setup_logger()
+
+
+class SAM2InferenceState(TypedDict):
+ images: list[torch.Tensor | None]
+ num_frames: int
+ cached_features: dict[int, Any]
+
+
+class EdgeTAMProcessor(Detector):
+ _predictor: "SAM2VideoPredictor"
+ _inference_state: SAM2InferenceState | None
+ _frame_count: int
+ _is_tracking: bool
+ _buffer_size: int
+
+ def __init__(
+ self,
+ ) -> None:
+ local_config_path = Path(__file__).parent / "configs" / "edgetam.yaml"
+
+ if not local_config_path.exists():
+ raise FileNotFoundError(f"EdgeTAM config not found at {local_config_path}")
+
+ if not torch.cuda.is_available():
+ raise RuntimeError("EdgeTAM requires a CUDA-capable GPU")
+
+ cfg = OmegaConf.load(local_config_path)
+
+ overrides = {
+ "model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability": True,
+ "model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta": 0.05,
+ "model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh": 0.98,
+ "model.binarize_mask_from_pts_for_mem_enc": True,
+ "model.fill_hole_area": 8,
+ }
+
+ for key, value in overrides.items():
+ OmegaConf.update(cfg, key, value)
+
+ if cfg.model._target_ != "sam2.sam2_video_predictor.SAM2VideoPredictor":
+ logger.warning(
+ f"Config target is {cfg.model._target_}, forcing SAM2VideoPredictor"
+ )
+ cfg.model._target_ = "sam2.sam2_video_predictor.SAM2VideoPredictor"
+
+ self._predictor = instantiate(cfg.model, _recursive_=True)
+
+ ckpt_path = str(get_data("models_edgetam") / "edgetam.pt")
+
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
+ missing_keys, unexpected_keys = self._predictor.load_state_dict(sd)
+ if missing_keys:
+ raise RuntimeError("Missing keys in checkpoint")
+ if unexpected_keys:
+ raise RuntimeError("Unexpected keys in checkpoint")
+
+ self._predictor = self._predictor.to("cuda")
+ self._predictor.eval()
+
+ self._inference_state = None
+ self._frame_count = 0
+ self._is_tracking = False
+ self._buffer_size = 100 # Keep last N frames in memory to avoid OOM
+
+ def _prepare_frame(self, image: Image) -> torch.Tensor:
+ """Prepare frame for SAM2 (resize, normalize, convert to tensor)."""
+
+ cv_image = image.to_opencv()
+ rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
+ pil_image = PILImage.fromarray(rgb_image)
+
+ img_np = np.array(
+ pil_image.resize((self._predictor.image_size, self._predictor.image_size))
+ )
+ img_np = img_np.astype(np.float32) / 255.0
+
+ img_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
+ img_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
+ img_np -= img_mean
+ img_np /= img_std
+
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).float()
+ img_tensor = img_tensor.cuda()
+
+ return img_tensor
+
+ def init_track(
+ self,
+ image: Image,
+ points: NDArray[np.floating[Any]] | None = None,
+ labels: NDArray[np.integer[Any]] | None = None,
+ box: NDArray[np.floating[Any]] | None = None,
+ obj_id: int = 1,
+ ) -> ImageDetections2D:
+ """Initialize tracking with a prompt (points or box).
+
+ Args:
+ image: Initial frame to start tracking from
+ points: Point prompts for segmentation (Nx2 array of [x, y] coordinates)
+ labels: Labels for points (1 = foreground, 0 = background)
+ box: Bounding box prompt in [x1, y1, x2, y2] format
+ obj_id: Object ID for tracking
+
+ Returns:
+ ImageDetections2D with initial segmentation mask
+ """
+ if self._inference_state is not None:
+ self.stop()
+
+ self._frame_count = 0
+
+ with _temp_dir_context(image) as video_path:
+ self._inference_state = self._predictor.init_state(video_path=video_path)
+
+ self._predictor.reset_state(self._inference_state)
+
+ if torch.is_tensor(self._inference_state["images"]):
+ self._inference_state["images"] = [self._inference_state["images"][0]]
+
+ self._is_tracking = True
+
+ if points is not None:
+ points = points.astype(np.float32)
+ if labels is not None:
+ labels = labels.astype(np.int32)
+ if box is not None:
+ box = box.astype(np.float32)
+
+ with torch.no_grad():
+ _, out_obj_ids, out_mask_logits = self._predictor.add_new_points_or_box(
+ inference_state=self._inference_state,
+ frame_idx=0,
+ obj_id=obj_id,
+ points=points,
+ labels=labels,
+ box=box,
+ )
+
+ return self._process_results(image, out_obj_ids, out_mask_logits)
+
+ def process_image(self, image: Image) -> ImageDetections2D:
+ """Process a new video frame and propagate tracking.
+
+ Args:
+ image: New frame to process
+
+ Returns:
+ ImageDetections2D with tracked object segmentation masks
+ """
+ if not self._is_tracking or self._inference_state is None:
+ return ImageDetections2D(image=image)
+
+ self._frame_count += 1
+
+ # Append new frame to inference state
+ new_frame_tensor = self._prepare_frame(image)
+ self._inference_state["images"].append(new_frame_tensor)
+ self._inference_state["num_frames"] += 1
+
+ # Memory management
+ cached_features = self._inference_state["cached_features"]
+ if len(cached_features) > self._buffer_size:
+ oldest_frame = min(cached_features.keys())
+ if oldest_frame < self._frame_count - self._buffer_size:
+ del cached_features[oldest_frame]
+
+ if len(self._inference_state["images"]) > self._buffer_size + 10:
+ idx_to_drop = self._frame_count - self._buffer_size - 5
+ if idx_to_drop >= 0 and idx_to_drop < len(self._inference_state["images"]):
+ if self._inference_state["images"][idx_to_drop] is not None:
+ self._inference_state["images"][idx_to_drop] = None
+
+ detections: ImageDetections2D = ImageDetections2D(image=image)
+
+ with torch.no_grad():
+ for out_frame_idx, out_obj_ids, out_mask_logits in self._predictor.propagate_in_video(
+ self._inference_state, start_frame_idx=self._frame_count, max_frame_num_to_track=1
+ ):
+ if out_frame_idx == self._frame_count:
+ return self._process_results(image, out_obj_ids, out_mask_logits)
+
+ return detections
+
+ def _process_results(
+ self,
+ image: Image,
+ obj_ids: list[int],
+ mask_logits: torch.Tensor | NDArray[np.floating[Any]],
+ ) -> ImageDetections2D:
+ detections: ImageDetections2D = ImageDetections2D(image=image)
+
+ if len(obj_ids) == 0:
+ return detections
+
+ if isinstance(mask_logits, torch.Tensor):
+ mask_logits = mask_logits.cpu().numpy()
+
+ for i, obj_id in enumerate(obj_ids):
+ mask = mask_logits[i]
+ seg = Detection2DSeg.from_sam2_result(
+ mask=mask,
+ obj_id=obj_id,
+ image=image,
+ name="object",
+ )
+
+ if seg.is_valid():
+ detections.detections.append(seg)
+
+ return detections
+
+ def stop(self) -> None:
+ self._is_tracking = False
+ self._inference_state = None
+
+
+@contextmanager
+def _temp_dir_context(image: Image) -> Generator[str, None, None]:
+ path = tempfile.mkdtemp()
+
+ image.save(f"{path}/00000.jpg")
+
+ try:
+ yield path
+ finally:
+ shutil.rmtree(path)
diff --git a/dimos/models/vl/__init__.py b/dimos/models/vl/__init__.py
index 6f120f9141..e4bb68e03c 100644
--- a/dimos/models/vl/__init__.py
+++ b/dimos/models/vl/__init__.py
@@ -2,6 +2,7 @@
from dimos.models.vl.florence import Florence2Model
from dimos.models.vl.moondream import MoondreamVlModel
from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
+from dimos.models.vl.openai import OpenAIVlModel
from dimos.models.vl.qwen import QwenVlModel
__all__ = [
@@ -9,6 +10,7 @@
"Florence2Model",
"MoondreamHostedVlModel",
"MoondreamVlModel",
+ "OpenAIVlModel",
"QwenVlModel",
"VlModel",
]
diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py
index c28a12363f..fc1f8b7a17 100644
--- a/dimos/models/vl/moondream_hosted.py
+++ b/dimos/models/vl/moondream_hosted.py
@@ -8,7 +8,7 @@
from dimos.models.vl.base import VlModel
from dimos.msgs.sensor_msgs import Image
-from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D
+from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D
class MoondreamHostedVlModel(VlModel):
@@ -107,29 +107,41 @@ def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetection
return image_detections
- def point(self, image: Image, query: str) -> list[tuple[float, float]]:
- """Get coordinates of specific objects in an image.
+ def query_points(
+ self, image: Image, query: str, **kwargs: object
+ ) -> ImageDetections2D[Detection2DPoint]:
+ """Detect point locations using Moondream's hosted point method.
Args:
image: Input image
- query: Object query
+ query: Object query (e.g., "person's head", "center of the ball")
Returns:
- List of (x, y) pixel coordinates
+ ImageDetections2D containing detected points
"""
pil_image = self._to_pil_image(image)
result = self._client.point(pil_image, query)
- points = result.get("points", [])
- pixel_points = []
+ image_detections: ImageDetections2D[Detection2DPoint] = ImageDetections2D(image)
height, width = image.height, image.width
- for p in points:
- x_norm = p.get("x", 0.0)
- y_norm = p.get("y", 0.0)
- pixel_points.append((x_norm * width, y_norm * height))
+ for track_id, point in enumerate(result.get("points", [])):
+ x = point.get("x", 0.0) * width
+ y = point.get("y", 0.0) * height
- return pixel_points
+ detection = Detection2DPoint(
+ x=x,
+ y=y,
+ name=query,
+ ts=image.ts,
+ image=image,
+ track_id=track_id,
+ )
+
+ if detection.is_valid():
+ image_detections.detections.append(detection)
+
+ return image_detections
def stop(self) -> None:
pass
diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py
new file mode 100644
index 0000000000..f596f1ee1e
--- /dev/null
+++ b/dimos/models/vl/openai.py
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from functools import cached_property
+import os
+from typing import Any
+
+import numpy as np
+from openai import OpenAI
+
+from dimos.models.vl.base import VlModel, VlModelConfig
+from dimos.msgs.sensor_msgs import Image
+from dimos.utils.logging_config import setup_logger
+
+logger = setup_logger()
+
+
+@dataclass
+class OpenAIVlModelConfig(VlModelConfig):
+ model_name: str = "gpt-4o-mini"
+ api_key: str | None = None
+
+
+class OpenAIVlModel(VlModel):
+ default_config = OpenAIVlModelConfig
+ config: OpenAIVlModelConfig
+
+ @cached_property
+ def _client(self) -> OpenAI:
+ api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ raise ValueError(
+ "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable"
+ )
+
+ return OpenAI(api_key=api_key)
+
+ def query(self, image: Image | np.ndarray, query: str, response_format: dict | None = None, **kwargs) -> str: # type: ignore[override, type-arg, no-untyped-def]
+ if isinstance(image, np.ndarray):
+ import warnings
+
+ warnings.warn(
+ "OpenAIVlModel.query should receive standard dimos Image type, not a numpy array",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ image = Image.from_numpy(image)
+
+ # Apply auto_resize if configured
+ image, _ = self._prepare_image(image)
+
+ img_base64 = image.to_base64()
+
+ api_kwargs: dict[str, Any] = {
+ "model": self.config.model_name,
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{img_base64}"},
+ },
+ {"type": "text", "text": query},
+ ],
+ }
+ ],
+ }
+
+ if response_format:
+ api_kwargs["response_format"] = response_format
+
+ response = self._client.chat.completions.create(**api_kwargs)
+
+ return response.choices[0].message.content # type: ignore[return-value,no-any-return]
+
+ def query_batch(
+ self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any
+ ) -> list[str]: # type: ignore[override]
+ """Query VLM with multiple images using a single API call."""
+ if not images:
+ return []
+
+ content: list[dict[str, Any]] = [
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"},
+ }
+ for img in images
+ ]
+ content.append({"type": "text", "text": query})
+
+ messages = [{"role": "user", "content": content}]
+ api_kwargs: dict[str, Any] = {"model": self.config.model_name, "messages": messages}
+ if response_format:
+ api_kwargs["response_format"] = response_format
+
+ response = self._client.chat.completions.create(**api_kwargs)
+ response_text = response.choices[0].message.content or ""
+ # Return one response per image (same response since API analyzes all images together)
+ return [response_text] * len(images)
+
+ def stop(self) -> None:
+ """Release the OpenAI client."""
+ if "_client" in self.__dict__:
+ del self.__dict__["_client"]
+
diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py
index b1d3d6f036..93b31bf74c 100644
--- a/dimos/models/vl/qwen.py
+++ b/dimos/models/vl/qwen.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from functools import cached_property
import os
+from typing import Any
import numpy as np
from openai import OpenAI
@@ -69,6 +70,32 @@ def query(self, image: Image | np.ndarray, query: str) -> str: # type: ignore[o
return response.choices[0].message.content # type: ignore[return-value]
+ def query_batch(
+ self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any
+ ) -> list[str]: # type: ignore[override]
+ """Query VLM with multiple images using a single API call."""
+ if not images:
+ return []
+
+ content: list[dict[str, Any]] = [
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"},
+ }
+ for img in images
+ ]
+ content.append({"type": "text", "text": query})
+
+ messages = [{"role": "user", "content": content}]
+ api_kwargs: dict[str, Any] = {"model": self.config.model_name, "messages": messages}
+ if response_format:
+ api_kwargs["response_format"] = response_format
+
+ response = self._client.chat.completions.create(**api_kwargs)
+ response_text = response.choices[0].message.content or ""
+ # Return one response per image (same response since API analyzes all images together)
+ return [response_text] * len(images)
+
def stop(self) -> None:
"""Release the OpenAI client."""
if "_client" in self.__dict__:
diff --git a/dimos/models/vl/test_vlm.py b/dimos/models/vl/test_vlm.py
index 1bf20eb680..741e0dede2 100644
--- a/dimos/models/vl/test_vlm.py
+++ b/dimos/models/vl/test_vlm.py
@@ -1,3 +1,4 @@
+import os
import time
from typing import TYPE_CHECKING
@@ -8,6 +9,7 @@
from dimos.core import LCMTransport
from dimos.models.vl.moondream import MoondreamVlModel
+from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
from dimos.models.vl.qwen import QwenVlModel
from dimos.msgs.sensor_msgs import Image
from dimos.perception.detection.type import ImageDetections2D
@@ -26,11 +28,15 @@
"model_class,model_name",
[
(MoondreamVlModel, "Moondream"),
+ (MoondreamHostedVlModel, "Moondream Hosted"),
(QwenVlModel, "Qwen"),
],
)
@pytest.mark.gpu
def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) -> None:
+ if model_class is MoondreamHostedVlModel and 'MOONDREAM_API_KEY' not in os.environ:
+ pytest.skip("Need MOONDREAM_API_KEY to run")
+
image = Image.from_file(get_data("cafe.jpg")).to_rgb()
print(f"Testing {model_name}")
@@ -94,12 +100,17 @@ def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) -> N
"model_class,model_name",
[
(MoondreamVlModel, "Moondream"),
+ (MoondreamHostedVlModel, "Moondream Hosted"),
(QwenVlModel, "Qwen"),
],
)
@pytest.mark.gpu
def test_vlm_point_detections(model_class: "type[VlModel]", model_name: str) -> None:
"""Test VLM point detection capabilities."""
+
+ if model_class is MoondreamHostedVlModel and 'MOONDREAM_API_KEY' not in os.environ:
+ pytest.skip("Need MOONDREAM_API_KEY to run")
+
image = Image.from_file(get_data("cafe.jpg")).to_rgb()
print(f"Testing {model_name} point detection")
diff --git a/dimos/msgs/__init__.py b/dimos/msgs/__init__.py
index e69de29bb2..b2bcabab01 100644
--- a/dimos/msgs/__init__.py
+++ b/dimos/msgs/__init__.py
@@ -0,0 +1,3 @@
+from dimos.msgs.protocol import DimosMsg
+
+__all__ = ["DimosMsg"]
diff --git a/dimos/msgs/geometry_msgs/Quaternion.py b/dimos/msgs/geometry_msgs/Quaternion.py
index d19436d441..02c9592ea6 100644
--- a/dimos/msgs/geometry_msgs/Quaternion.py
+++ b/dimos/msgs/geometry_msgs/Quaternion.py
@@ -138,6 +138,20 @@ def from_euler(cls, vector: Vector3) -> Quaternion:
return cls(x, y, z, w)
+ @classmethod
+ def from_rotation_matrix(cls, matrix: np.ndarray) -> Quaternion: # type: ignore[type-arg]
+ """Convert a 3x3 rotation matrix to quaternion.
+
+ Args:
+ matrix: 3x3 rotation matrix (numpy array)
+
+ Returns:
+ Quaternion representation
+ """
+ rotation = R.from_matrix(matrix)
+ quat = rotation.as_quat() # Returns [x, y, z, w]
+ return cls(quat[0], quat[1], quat[2], quat[3])
+
def to_euler(self) -> Vector3:
"""Convert quaternion to Euler angles (roll, pitch, yaw) in radians.
diff --git a/dimos/msgs/geometry_msgs/Vector3.py b/dimos/msgs/geometry_msgs/Vector3.py
index 907079d5c1..8be8cdf2ae 100644
--- a/dimos/msgs/geometry_msgs/Vector3.py
+++ b/dimos/msgs/geometry_msgs/Vector3.py
@@ -15,11 +15,10 @@
from __future__ import annotations
from collections.abc import Sequence
-from typing import TypeAlias
+from typing import Any, TypeAlias
from dimos_lcm.geometry_msgs import Vector3 as LCMVector3
import numpy as np
-from plum import dispatch
# Types that can be converted to/from Vector
VectorConvertable: TypeAlias = Sequence[int | float] | LCMVector3 | np.ndarray # type: ignore[type-arg]
@@ -43,65 +42,74 @@ class Vector3(LCMVector3): # type: ignore[misc]
x: float = 0.0
y: float = 0.0
z: float = 0.0
- msg_name = "geometry_msgs.Vector3"
-
- @dispatch
- def __init__(self) -> None:
- """Initialize a zero 3D vector."""
- self.x = 0.0
- self.y = 0.0
- self.z = 0.0
-
- @dispatch # type: ignore[no-redef]
- def __init__(self, x: int | float) -> None:
- """Initialize a 3D vector from a single numeric value (x, 0, 0)."""
- self.x = float(x)
- self.y = 0.0
- self.z = 0.0
-
- @dispatch # type: ignore[no-redef]
- def __init__(self, x: int | float, y: int | float) -> None:
- """Initialize a 3D vector from x, y components (z=0)."""
- self.x = float(x)
- self.y = float(y)
- self.z = 0.0
-
- @dispatch # type: ignore[no-redef]
- def __init__(self, x: int | float, y: int | float, z: int | float) -> None:
- """Initialize a 3D vector from x, y, z components."""
- self.x = float(x)
- self.y = float(y)
- self.z = float(z)
-
- @dispatch # type: ignore[no-redef]
- def __init__(self, sequence: Sequence[int | float]) -> None:
- """Initialize from a sequence (list, tuple) of numbers, ensuring 3D."""
- data = _ensure_3d(np.array(sequence, dtype=float))
- self.x = float(data[0])
- self.y = float(data[1])
- self.z = float(data[2])
-
- @dispatch # type: ignore[no-redef]
- def __init__(self, array: np.ndarray) -> None: # type: ignore[type-arg]
- """Initialize from a numpy array, ensuring 3D."""
- data = _ensure_3d(np.array(array, dtype=float))
- self.x = float(data[0])
- self.y = float(data[1])
- self.z = float(data[2])
-
- @dispatch # type: ignore[no-redef]
- def __init__(self, vector: Vector3) -> None:
- """Initialize from another Vector3 (copy constructor)."""
- self.x = vector.x
- self.y = vector.y
- self.z = vector.z
-
- @dispatch # type: ignore[no-redef]
- def __init__(self, lcm_vector: LCMVector3) -> None:
- """Initialize from an LCM Vector3."""
- self.x = float(lcm_vector.x)
- self.y = float(lcm_vector.y)
- self.z = float(lcm_vector.z)
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ """Initialize a 3D vector.
+
+ Supported forms:
+ Vector3() # zero vector
+ Vector3(x) # (x, 0, 0)
+ Vector3(x, y) # (x, y, 0)
+ Vector3(x, y, z) # (x, y, z)
+ Vector3(x=1, y=2, z=3) # keyword args
+ Vector3([x, y, z]) # sequence
+ Vector3(np.array([x, y, z])) # numpy array
+ Vector3(other_vector3) # copy constructor
+ Vector3(lcm_vector3) # from LCM message
+ """
+ if kwargs and not args:
+ # Keyword arguments: Vector3(x=1, y=2, z=3)
+ self.x = float(kwargs.get("x", 0.0))
+ self.y = float(kwargs.get("y", 0.0))
+ self.z = float(kwargs.get("z", 0.0))
+ elif not args:
+ # No arguments: zero vector
+ self.x = 0.0
+ self.y = 0.0
+ self.z = 0.0
+ elif len(args) == 1:
+ arg = args[0]
+ if isinstance(arg, Vector3):
+ # Copy constructor
+ self.x = arg.x
+ self.y = arg.y
+ self.z = arg.z
+ elif isinstance(arg, LCMVector3):
+ # From LCM Vector3
+ self.x = float(arg.x)
+ self.y = float(arg.y)
+ self.z = float(arg.z)
+ elif isinstance(arg, np.ndarray):
+ # From numpy array
+ data = _ensure_3d(np.array(arg, dtype=float))
+ self.x = float(data[0])
+ self.y = float(data[1])
+ self.z = float(data[2])
+ elif isinstance(arg, (list, tuple)):
+ # From sequence
+ data = _ensure_3d(np.array(arg, dtype=float))
+ self.x = float(data[0])
+ self.y = float(data[1])
+ self.z = float(data[2])
+ elif isinstance(arg, (int, float)):
+ # Single numeric value: (x, 0, 0)
+ self.x = float(arg)
+ self.y = 0.0
+ self.z = 0.0
+ else:
+ raise TypeError(f"Cannot initialize Vector3 from {type(arg)}")
+ elif len(args) == 2:
+ # Two numeric values: (x, y, 0)
+ self.x = float(args[0])
+ self.y = float(args[1])
+ self.z = 0.0
+ elif len(args) == 3:
+ # Three numeric values: (x, y, z)
+ self.x = float(args[0])
+ self.y = float(args[1])
+ self.z = float(args[2])
+ else:
+ raise TypeError(f"Vector3 takes at most 3 positional arguments ({len(args)} given)")
@property
def as_tuple(self) -> tuple[float, float, float]:
@@ -124,7 +132,7 @@ def data(self) -> np.ndarray: # type: ignore[type-arg]
"""Get the underlying numpy array."""
return np.array([self.x, self.y, self.z], dtype=float)
- def __getitem__(self, idx: int): # type: ignore[no-untyped-def]
+ def __getitem__(self, idx: int) -> float:
if idx == 0:
return self.x
elif idx == 1:
@@ -138,7 +146,7 @@ def __repr__(self) -> str:
return f"Vector({self.data})"
def __str__(self) -> str:
- def getArrow(): # type: ignore[no-untyped-def]
+ def getArrow() -> str:
repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"]
if self.x == 0 and self.y == 0:
@@ -151,21 +159,21 @@ def getArrow(): # type: ignore[no-untyped-def]
# Get directional arrow symbol
return repr[dir_index]
- return f"{getArrow()} Vector {self.__repr__()}" # type: ignore[no-untyped-call]
+ return f"{getArrow()} Vector {self.__repr__()}"
- def agent_encode(self) -> dict: # type: ignore[type-arg]
+ def agent_encode(self) -> dict[str, float]:
"""Encode the vector for agent communication."""
return {"x": self.x, "y": self.y, "z": self.z}
- def serialize(self) -> dict: # type: ignore[type-arg]
+ def serialize(self) -> dict[str, Any]:
"""Serialize the vector to a tuple."""
return {"type": "vector", "c": (self.x, self.y, self.z)}
- def __eq__(self, other) -> bool: # type: ignore[no-untyped-def]
+ def __eq__(self, other: object) -> bool:
"""Check if two vectors are equal using numpy's allclose for floating point comparison."""
if not isinstance(other, Vector3):
return False
- return np.allclose([self.x, self.y, self.z], [other.x, other.y, other.z])
+ return bool(np.allclose([self.x, self.y, self.z], [other.x, other.y, other.z]))
def __add__(self, other: VectorConvertable | Vector3) -> Vector3:
other_vector: Vector3 = to_vector(other)
@@ -194,7 +202,7 @@ def __neg__(self) -> Vector3:
def dot(self, other: VectorConvertable | Vector3) -> float:
"""Compute dot product."""
other_vector = to_vector(other)
- return self.x * other_vector.x + self.y * other_vector.y + self.z * other_vector.z # type: ignore[no-any-return]
+ return float(self.x * other_vector.x + self.y * other_vector.y + self.z * other_vector.z)
def cross(self, other: VectorConvertable | Vector3) -> Vector3:
"""Compute cross product (3D vectors only)."""
@@ -321,13 +329,13 @@ def is_zero(self) -> bool:
Returns:
True if all components are zero, False otherwise
"""
- return np.allclose([self.x, self.y, self.z], 0.0)
+ return bool(np.allclose([self.x, self.y, self.z], 0.0))
@property
- def quaternion(self): # type: ignore[no-untyped-def]
- return self.to_quaternion() # type: ignore[no-untyped-call]
+ def quaternion(self) -> Quaternion: # type: ignore[name-defined]
+ return self.to_quaternion()
- def to_quaternion(self): # type: ignore[no-untyped-def]
+ def to_quaternion(self) -> Quaternion: # type: ignore[name-defined]
"""Convert Vector3 representing Euler angles (roll, pitch, yaw) to a Quaternion.
Assumes this Vector3 contains Euler angles in radians:
@@ -377,73 +385,43 @@ def __bool__(self) -> bool:
return not self.is_zero()
-@dispatch
-def to_numpy(value: Vector3) -> np.ndarray: # type: ignore[type-arg]
- """Convert a Vector3 to a numpy array."""
- return value.to_numpy()
-
-
-@dispatch # type: ignore[no-redef]
-def to_numpy(value: np.ndarray) -> np.ndarray: # type: ignore[type-arg]
- """Pass through numpy arrays."""
- return value
-
-
-@dispatch # type: ignore[no-redef]
-def to_numpy(value: Sequence[int | float]) -> np.ndarray: # type: ignore[type-arg]
- """Convert a sequence to a numpy array."""
- return np.array(value, dtype=float)
-
-
-@dispatch
-def to_vector(value: Vector3) -> Vector3:
- """Pass through Vector3 objects."""
- return value
+def to_numpy(value: Vector3 | np.ndarray | Sequence[int | float]) -> np.ndarray: # type: ignore[type-arg]
+ """Convert a value to a numpy array."""
+ if isinstance(value, Vector3):
+ return value.to_numpy()
+ elif isinstance(value, np.ndarray):
+ return value
+ else:
+ return np.array(value, dtype=float)
-@dispatch # type: ignore[no-redef]
def to_vector(value: VectorConvertable | Vector3) -> Vector3:
"""Convert a vector-compatible value to a Vector3 object."""
+ if isinstance(value, Vector3):
+ return value
return Vector3(value)
-@dispatch
-def to_tuple(value: Vector3) -> tuple[float, float, float]:
- """Convert a Vector3 to a tuple."""
- return value.to_tuple()
-
-
-@dispatch # type: ignore[no-redef]
-def to_tuple(value: np.ndarray) -> tuple[float, ...]: # type: ignore[type-arg]
- """Convert a numpy array to a tuple."""
- return tuple(value.tolist())
-
-
-@dispatch # type: ignore[no-redef]
-def to_tuple(value: Sequence[int | float]) -> tuple[float, ...]:
- """Convert a sequence to a tuple."""
- if isinstance(value, tuple):
+def to_tuple(value: Vector3 | np.ndarray | Sequence[int | float]) -> tuple[float, ...]: # type: ignore[type-arg]
+ """Convert a value to a tuple."""
+ if isinstance(value, Vector3):
+ return value.to_tuple()
+ elif isinstance(value, np.ndarray):
+ return tuple(value.tolist())
+ elif isinstance(value, tuple):
return value
else:
return tuple(value)
-@dispatch
-def to_list(value: Vector3) -> list[float]:
- """Convert a Vector3 to a list."""
- return value.to_list()
-
-
-@dispatch # type: ignore[no-redef]
-def to_list(value: np.ndarray) -> list[float]: # type: ignore[type-arg]
- """Convert a numpy array to a list."""
- return value.tolist()
-
-
-@dispatch # type: ignore[no-redef]
-def to_list(value: Sequence[int | float]) -> list[float]:
- """Convert a sequence to a list."""
- if isinstance(value, list):
+def to_list(value: Vector3 | np.ndarray | Sequence[int | float]) -> list[float]: # type: ignore[type-arg]
+ """Convert a value to a list."""
+ if isinstance(value, Vector3):
+ return value.to_list()
+ elif isinstance(value, np.ndarray):
+ result: list[float] = value.tolist()
+ return result
+ elif isinstance(value, list):
return value
else:
return list(value)
diff --git a/dimos/msgs/protocol.py b/dimos/msgs/protocol.py
new file mode 100644
index 0000000000..38d7ca57e2
--- /dev/null
+++ b/dimos/msgs/protocol.py
@@ -0,0 +1,31 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Protocol, runtime_checkable
+
+
+@runtime_checkable
+class DimosMsg(Protocol):
+ """Protocol for dimos message types (LCM-based messages from dimos.msgs)."""
+
+ msg_name: str
+
+ @classmethod
+ def lcm_decode(cls, data: bytes) -> "DimosMsg":
+ """Decode bytes into a message instance."""
+ ...
+
+ def lcm_encode(self) -> bytes:
+ """Encode this message instance into bytes."""
+ ...
diff --git a/dimos/msgs/sensor_msgs/CameraInfo.py b/dimos/msgs/sensor_msgs/CameraInfo.py
index b6f85dbaca..855276b4e6 100644
--- a/dimos/msgs/sensor_msgs/CameraInfo.py
+++ b/dimos/msgs/sensor_msgs/CameraInfo.py
@@ -20,7 +20,6 @@
from dimos_lcm.sensor_msgs import CameraInfo as LCMCameraInfo
from dimos_lcm.std_msgs.Header import Header
import numpy as np
-import rerun as rr
# Import ROS types
try:
@@ -405,6 +404,8 @@ def to_rerun(self, image_plane_distance: float = 0.5): # type: ignore[no-untype
Returns:
rr.Pinhole archetype for logging to Rerun
"""
+ import rerun as rr
+
# Extract intrinsics from K matrix
# K = [fx, 0, cx, 0, fy, cy, 0, 0, 1]
fx, fy = self.K[0], self.K[4]
diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py
index 26ed141867..de3e7abeca 100644
--- a/dimos/msgs/sensor_msgs/Image.py
+++ b/dimos/msgs/sensor_msgs/Image.py
@@ -310,6 +310,10 @@ def to_cupy(self) -> Image:
def to_opencv(self) -> np.ndarray: # type: ignore[type-arg]
return self._impl.to_opencv()
+ def as_numpy(self) -> np.ndarray: # type: ignore[type-arg]
+ """Get image data as numpy array in RGB format."""
+ return np.asarray(self.data)
+
def to_rgb(self) -> Image:
return Image(self._impl.to_rgb())
@@ -352,6 +356,20 @@ def sharpness(self) -> float:
"""Return sharpness score."""
return self._impl.sharpness()
+ def to_depth_meters(self) -> Image:
+ """Return a depth image normalized to meters as float32."""
+ depth_cv = self.to_opencv()
+ fmt = self.format
+
+ if fmt == ImageFormat.DEPTH16:
+ depth_cv = depth_cv.astype(np.float32) / 1000.0
+ fmt = ImageFormat.DEPTH
+ elif depth_cv.dtype != np.float32:
+ depth_cv = depth_cv.astype(np.float32)
+            # dtype normalized to float32; format tag intentionally left unchanged
+
+ return Image.from_numpy(depth_cv, format=fmt, frame_id=self.frame_id, ts=self.ts)
+
def save(self, filepath: str) -> bool:
return self._impl.save(filepath)
diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py
index 1f048d5eaf..d68c62f51d 100644
--- a/dimos/msgs/sensor_msgs/PointCloud2.py
+++ b/dimos/msgs/sensor_msgs/PointCloud2.py
@@ -23,11 +23,9 @@
)
from dimos_lcm.sensor_msgs.PointField import PointField # type: ignore[import-untyped]
from dimos_lcm.std_msgs.Header import Header # type: ignore[import-untyped]
-import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d # type: ignore[import-untyped]
import open3d.core as o3c # type: ignore[import-untyped]
-import rerun as rr
from dimos.msgs.geometry_msgs import Transform, Vector3
@@ -55,6 +53,8 @@
@functools.lru_cache(maxsize=16)
def _get_matplotlib_cmap(name: str): # type: ignore[no-untyped-def]
"""Get a matplotlib colormap by name (cached for performance)."""
+ import matplotlib.pyplot as plt
+
return plt.get_cmap(name)
@@ -354,6 +354,8 @@ def transform(self, tf: Transform) -> PointCloud2:
def voxel_downsample(self, voxel_size: float = 0.025) -> PointCloud2:
"""Downsample the pointcloud with a voxel grid."""
+ if voxel_size <= 0:
+ return self
if len(self.pointcloud.points) < 20:
return self
downsampled = self._pcd_tensor.voxel_down_sample(voxel_size)
@@ -626,6 +628,8 @@ def to_rerun( # type: ignore[no-untyped-def]
fill_mode: str = "solid",
**kwargs, # type: ignore[no-untyped-def]
): # type: ignore[no-untyped-def]
+ import rerun as rr
+
"""Convert to Rerun Points3D or Boxes3D archetype.
Args:
diff --git a/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py b/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py
index f5d92a3bc6..b71c5476fc 100644
--- a/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py
+++ b/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py
@@ -32,6 +32,31 @@
cp = None
HAS_CUDA = False
+# NVRTC defaults to C++11; libcu++ in recent CUDA requires at least C++17.
+if HAS_CUDA:
+ try:
+ import cupy.cuda.compiler as _cupy_compiler # type: ignore[import-not-found]
+
+ if not getattr(_cupy_compiler, "_dimos_force_cxx17", False):
+ _orig_compile_using_nvrtc = _cupy_compiler.compile_using_nvrtc
+
+ def _compile_using_nvrtc( # type: ignore[no-untyped-def]
+ source, options=(), *args, **kwargs
+ ):
+ filtered = tuple(
+ opt
+ for opt in options
+ if opt not in ("-std=c++11", "--std=c++11", "-std=c++14", "--std=c++14")
+ )
+ if "--std=c++17" not in filtered and "-std=c++17" not in filtered:
+ filtered = (*filtered, "--std=c++17")
+ return _orig_compile_using_nvrtc(source, filtered, *args, **kwargs)
+
+ _cupy_compiler.compile_using_nvrtc = _compile_using_nvrtc
+ _cupy_compiler._dimos_force_cxx17 = True
+ except Exception:
+ pass
+
# Optional nvImageCodec (preferred GPU codec)
USE_NVIMGCODEC = os.environ.get("USE_NVIMGCODEC", "0") == "1"
NVIMGCODEC_LAST_USED = False
diff --git a/dimos/msgs/sensor_msgs/image_impls/CudaImage.py b/dimos/msgs/sensor_msgs/image_impls/CudaImage.py
index 8230daae29..cdfa1bf088 100644
--- a/dimos/msgs/sensor_msgs/image_impls/CudaImage.py
+++ b/dimos/msgs/sensor_msgs/image_impls/CudaImage.py
@@ -262,13 +262,20 @@
} // extern "C"
"""
+_pnp_kernel = None
if cp is not None:
- _mod = cp.RawModule(code=_CUDA_SRC, options=("-std=c++14",), name_expressions=("pnp_gn_batch",))
- _pnp_kernel = _mod.get_function("pnp_gn_batch")
+ try:
+ _mod = cp.RawModule(
+ code=_CUDA_SRC, options=("-std=c++17",), name_expressions=("pnp_gn_batch",)
+ )
+ _pnp_kernel = _mod.get_function("pnp_gn_batch")
+ except Exception:
+ # CUDA not available at runtime (e.g., no GPU or driver issues)
+ pass
def _solve_pnp_cuda_kernel(obj, img, K, iterations: int = 15, damping: float = 1e-6): # type: ignore[no-untyped-def]
- if cp is None:
+ if cp is None or _pnp_kernel is None:
raise RuntimeError("CuPy/CUDA not available")
obj_cu = cp.asarray(obj, dtype=cp.float32)
@@ -709,7 +716,7 @@ def sharpness(self) -> float:
magnitude = cp.hypot(gx, gy)
mean_mag = float(cp.asnumpy(magnitude.mean()))
except Exception:
- return 0.0
+ raise
if mean_mag <= 0:
return 0.0
return float(np.clip((np.log10(mean_mag + 1) - 1.7) / 2.0, 0.0, 1.0))
diff --git a/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py b/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py
index b1de0ac777..7951a095b3 100644
--- a/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py
+++ b/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py
@@ -221,19 +221,13 @@ def test_perf_alloc(alloc_timer) -> None:
def test_sharpness(alloc_timer) -> None:
"""Test sharpness computation with NumpyImage always, add CudaImage parity when available."""
arr = _prepare_image(ImageFormat.BGR, (64, 64, 3))
- cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR)
+ cpu = alloc_timer(arr, ImageFormat.BGR)[0]
# Always test CPU backend
s_cpu = cpu.sharpness
assert s_cpu >= 0 # Sharpness should be non-negative
assert s_cpu < 1000 # Reasonable upper bound
- # Optionally test GPU parity when CUDA is available
- if gpu is not None:
- s_gpu = gpu.sharpness
- # Values should be very close; minor border/rounding differences allowed
- assert abs(s_cpu - s_gpu) < 5e-2
-
def test_to_opencv(alloc_timer) -> None:
"""Test to_opencv conversion with NumpyImage always, add CudaImage parity when available."""
@@ -356,6 +350,7 @@ def test_perf_resize(alloc_timer) -> None:
print(f"resize (avg per call) cpu={cpu_t:.6f}s")
+@pytest.mark.integration
def test_perf_sharpness(alloc_timer) -> None:
"""Test sharpness performance with NumpyImage always, add CudaImage when available."""
arr = _prepare_image(ImageFormat.BGR, (480, 640, 3))
diff --git a/dimos/msgs/sensor_msgs/test_PointCloud2.py b/dimos/msgs/sensor_msgs/test_PointCloud2.py
index e5cd11da8c..652ff08921 100644
--- a/dimos/msgs/sensor_msgs/test_PointCloud2.py
+++ b/dimos/msgs/sensor_msgs/test_PointCloud2.py
@@ -26,7 +26,7 @@
ROSHeader = None
from dimos.msgs.sensor_msgs import PointCloud2
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.robot.unitree_webrtc.type.lidar import pointcloud2_from_webrtc_lidar
from dimos.utils.testing import SensorReplay
# Try to import ROS types for testing
@@ -38,8 +38,8 @@
def test_lcm_encode_decode() -> None:
"""Test LCM encode/decode preserves pointcloud data."""
- replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg)
- lidar_msg: LidarMessage = replay.load_one("lidar_data_021")
+ replay = SensorReplay("office_lidar", autocast=pointcloud2_from_webrtc_lidar)
+ lidar_msg: PointCloud2 = replay.load_one("lidar_data_021")
binary_msg = lidar_msg.lcm_encode()
decoded = PointCloud2.lcm_decode(binary_msg)
diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py
index 29e890de47..54eaaf9215 100644
--- a/dimos/msgs/tf2_msgs/TFMessage.py
+++ b/dimos/msgs/tf2_msgs/TFMessage.py
@@ -164,7 +164,10 @@ def to_rerun(self): # type: ignore[no-untyped-def]
"""Convert to a list of rerun Transform3D archetypes.
Returns a list of tuples (entity_path, Transform3D) for each transform
- in the message. The entity_path is derived from the child_frame_id.
+ in the message. The entity_path is derived from the child_frame_id and
+ logged under `world/tf/...` so it is visible under the default `world`
+ origin while keeping TF visualization isolated from semantic entities
+ like `world/robot/...`.
Returns:
List of (entity_path, rr.Transform3D) tuples
@@ -175,6 +178,6 @@ def to_rerun(self): # type: ignore[no-untyped-def]
"""
results = []
for transform in self.transforms:
- entity_path = f"world/{transform.child_frame_id}"
+ entity_path = f"world/tf/{transform.child_frame_id}"
results.append((entity_path, transform.to_rerun())) # type: ignore[no-untyped-call]
return results
diff --git a/dimos/msgs/vision_msgs/Detection2D.py b/dimos/msgs/vision_msgs/Detection2D.py
new file mode 100644
index 0000000000..aa957f8061
--- /dev/null
+++ b/dimos/msgs/vision_msgs/Detection2D.py
@@ -0,0 +1,27 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dimos_lcm.vision_msgs.Detection2D import Detection2D as LCMDetection2D
+
+from dimos.types.timestamped import to_timestamp
+
+
+class Detection2D(LCMDetection2D): # type: ignore[misc]
+ msg_name = "vision_msgs.Detection2D"
+
+ # for _get_field_type() to work when decoding in _decode_one()
+ __annotations__ = LCMDetection2D.__annotations__
+
+ @property
+ def ts(self) -> float:
+ return to_timestamp(self.header.stamp)
diff --git a/dimos/msgs/vision_msgs/Detection3D.py b/dimos/msgs/vision_msgs/Detection3D.py
new file mode 100644
index 0000000000..e074ecb0b1
--- /dev/null
+++ b/dimos/msgs/vision_msgs/Detection3D.py
@@ -0,0 +1,27 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dimos_lcm.vision_msgs.Detection3D import Detection3D as LCMDetection3D
+
+from dimos.types.timestamped import to_timestamp
+
+
+class Detection3D(LCMDetection3D): # type: ignore[misc]
+ msg_name = "vision_msgs.Detection3D"
+
+ # for _get_field_type() to work when decoding in _decode_one()
+ __annotations__ = LCMDetection3D.__annotations__
+
+ @property
+ def ts(self) -> float:
+ return to_timestamp(self.header.stamp)
diff --git a/dimos/msgs/vision_msgs/Detection3DArray.py b/dimos/msgs/vision_msgs/Detection3DArray.py
index 59905cad4c..2eba82204d 100644
--- a/dimos/msgs/vision_msgs/Detection3DArray.py
+++ b/dimos/msgs/vision_msgs/Detection3DArray.py
@@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from dimos_lcm.vision_msgs.Detection3DArray import Detection3DArray as LCMDetection3DArray
-from dimos_lcm.vision_msgs.Detection3DArray import (
- Detection3DArray as LCMDetection3DArray,
-)
+from dimos.types.timestamped import to_timestamp
class Detection3DArray(LCMDetection3DArray): # type: ignore[misc]
msg_name = "vision_msgs.Detection3DArray"
+
+ # for _get_field_type() to work when decoding in _decode_one()
+ __annotations__ = LCMDetection3DArray.__annotations__
+
+ @property
+ def ts(self) -> float:
+ return to_timestamp(self.header.stamp)
diff --git a/dimos/msgs/vision_msgs/__init__.py b/dimos/msgs/vision_msgs/__init__.py
index af170cbfab..0f1c9c8dc1 100644
--- a/dimos/msgs/vision_msgs/__init__.py
+++ b/dimos/msgs/vision_msgs/__init__.py
@@ -1,6 +1,15 @@
from .BoundingBox2DArray import BoundingBox2DArray
from .BoundingBox3DArray import BoundingBox3DArray
+from .Detection2D import Detection2D
from .Detection2DArray import Detection2DArray
+from .Detection3D import Detection3D
from .Detection3DArray import Detection3DArray
-__all__ = ["BoundingBox2DArray", "BoundingBox3DArray", "Detection2DArray", "Detection3DArray"]
+__all__ = [
+ "BoundingBox2DArray",
+ "BoundingBox3DArray",
+ "Detection2D",
+ "Detection2DArray",
+ "Detection3D",
+ "Detection3DArray",
+]
diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py
index aca154a6dd..1c8082b414 100644
--- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py
+++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py
@@ -15,12 +15,10 @@
import time
import numpy as np
-from PIL import ImageDraw
import pytest
from dimos.msgs.geometry_msgs import Vector3
from dimos.msgs.nav_msgs import CostValues, OccupancyGrid
-from dimos.navigation.frontier_exploration.utils import costmap_to_pil_image
from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import (
WavefrontFrontierExplorer,
)
@@ -312,92 +310,6 @@ def test_exploration_with_no_gain_detection() -> None:
explorer.stop()
-@pytest.mark.vis
-def test_frontier_detection_visualization() -> None:
- """Test frontier detection with visualization (marked with @pytest.mark.vis)."""
- # Get test costmap
- costmap, first_lidar = create_test_costmap()
-
- # Initialize frontier explorer with default parameters
- explorer = WavefrontFrontierExplorer()
-
- try:
- # Use lidar origin as robot position
- robot_pose = first_lidar.origin
-
- # Detect all frontiers for visualization
- all_frontiers = explorer.detect_frontiers(robot_pose, costmap)
-
- # Get selected goal
- selected_goal = explorer.get_exploration_goal(robot_pose, costmap)
-
- print(f"Visualizing {len(all_frontiers)} frontier candidates")
- if selected_goal:
- print(f"Selected goal: ({selected_goal.x:.2f}, {selected_goal.y:.2f})")
-
- # Create visualization
- image_scale_factor = 4
- base_image = costmap_to_pil_image(costmap, image_scale_factor)
-
- # Helper function to convert world coordinates to image coordinates
- def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]:
- grid_pos = costmap.world_to_grid(world_pos)
- img_x = int(grid_pos.x * image_scale_factor)
- img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y
- return img_x, img_y
-
- # Draw visualization
- draw = ImageDraw.Draw(base_image)
-
- # Draw frontier candidates as gray dots
- for frontier in all_frontiers[:20]: # Limit to top 20
- x, y = world_to_image_coords(frontier)
- radius = 6
- draw.ellipse(
- [x - radius, y - radius, x + radius, y + radius],
- fill=(128, 128, 128), # Gray
- outline=(64, 64, 64),
- width=1,
- )
-
- # Draw robot position as blue dot
- robot_x, robot_y = world_to_image_coords(robot_pose)
- robot_radius = 10
- draw.ellipse(
- [
- robot_x - robot_radius,
- robot_y - robot_radius,
- robot_x + robot_radius,
- robot_y + robot_radius,
- ],
- fill=(0, 0, 255), # Blue
- outline=(0, 0, 128),
- width=3,
- )
-
- # Draw selected goal as red dot
- if selected_goal:
- goal_x, goal_y = world_to_image_coords(selected_goal)
- goal_radius = 12
- draw.ellipse(
- [
- goal_x - goal_radius,
- goal_y - goal_radius,
- goal_x + goal_radius,
- goal_y + goal_radius,
- ],
- fill=(255, 0, 0), # Red
- outline=(128, 0, 0),
- width=3,
- )
-
- # Display the image
- base_image.show(title="Frontier Detection - Office Lidar")
- print("Visualization displayed. Close the image window to continue.")
- finally:
- explorer.stop()
-
-
def test_performance_timing() -> None:
"""Test performance by timing frontier detection operations."""
import time
diff --git a/dimos/navigation/replanning_a_star/local_planner.py b/dimos/navigation/replanning_a_star/local_planner.py
index cc5f6164dc..65a18d0637 100644
--- a/dimos/navigation/replanning_a_star/local_planner.py
+++ b/dimos/navigation/replanning_a_star/local_planner.py
@@ -29,7 +29,7 @@
from dimos.msgs.nav_msgs import Path
from dimos.msgs.sensor_msgs import Image
from dimos.navigation.base import NavigationState
-from dimos.navigation.replanning_a_star.controllers import Controller, PController, PdController
+from dimos.navigation.replanning_a_star.controllers import Controller, PController
from dimos.navigation.replanning_a_star.navigation_map import NavigationMap
from dimos.navigation.replanning_a_star.path_clearance import PathClearance
from dimos.navigation.replanning_a_star.path_distancer import PathDistancer
@@ -87,9 +87,7 @@ def __init__(
self._navigation_map = navigation_map
self._goal_tolerance = goal_tolerance
- controller = PController if global_config.simulation else PdController
-
- self._controller = controller(
+ self._controller = PController(
self._global_config,
self._speed,
self._control_frequency,
diff --git a/dimos/navigation/visual_servoing/detection_navigation.py b/dimos/navigation/visual_servoing/detection_navigation.py
new file mode 100644
index 0000000000..5f89bd1faa
--- /dev/null
+++ b/dimos/navigation/visual_servoing/detection_navigation.py
@@ -0,0 +1,208 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dimos_lcm.sensor_msgs import CameraInfo as DimosLcmCameraInfo
+import numpy as np
+
+from dimos.msgs.geometry_msgs import Transform, Twist, Vector3
+from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
+from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox
+from dimos.perception.detection.type.detection3d import Detection3DPC
+from dimos.protocol.tf import LCMTF
+from dimos.utils.logging_config import setup_logger
+
+logger = setup_logger()
+
+
+class DetectionNavigation:
+ _target_distance_3d: float = 1.5 # meters to maintain from person
+ _min_distance_3d: float = 0.8 # meters before backing up
+ _max_linear_speed_3d: float = 0.5 # m/s
+ _max_angular_speed_3d: float = 0.8 # rad/s
+ _linear_gain_3d: float = 0.8
+ _angular_gain_3d: float = 1.5
+
+ _tf: LCMTF
+ _camera_info: CameraInfo
+
+ def __init__(self, tf: LCMTF, camera_info: CameraInfo) -> None:
+ self._tf = tf
+ self._camera_info = camera_info
+
+ def compute_twist_for_detection_3d(
+ self, pointcloud: PointCloud2, detection: Detection2DBBox, image: Image
+ ) -> Twist | None:
+ """Project a 2D detection to 3D using pointcloud and compute navigation twist.
+
+        Args:
+            pointcloud: World-frame pointcloud used to project the 2D detection to 3D
+            detection: 2D detection with bounding box; image: current frame (its timestamp is used for the TF lookup)
+
+        Returns:
+            Twist command to navigate towards the detection's 3D position, or None if a TF lookup, the 3D projection, or the target computation fails.
+ """
+
+ # Get transform from world frame to camera optical frame
+ world_to_optical = self._tf.get(
+ "camera_optical", pointcloud.frame_id, image.ts, time_tolerance=1.0
+ )
+ if world_to_optical is None:
+ logger.warning("Could not get camera transform")
+ return None
+
+ lcm_camera_info = DimosLcmCameraInfo()
+ lcm_camera_info.K = self._camera_info.K
+ lcm_camera_info.width = self._camera_info.width
+ lcm_camera_info.height = self._camera_info.height
+
+ # Project to 3D using the pointcloud
+ detection_3d = Detection3DPC.from_2d(
+ det=detection,
+ world_pointcloud=pointcloud,
+ camera_info=lcm_camera_info,
+ world_to_optical_transform=world_to_optical,
+ filters=[], # Skip filtering for faster processing in follow loop
+ )
+
+ if detection_3d is None:
+ logger.warning("3D projection failed")
+ return None
+
+ # Get robot position to compute robust target
+ robot_transform = self._tf.get("world", "base_link", time_tolerance=1.0)
+ if robot_transform is None:
+ logger.warning("Could not get robot transform")
+ return None
+
+ robot_pos = robot_transform.translation
+
+ # Compute robust target position using front-most points
+ target_position = self._compute_robust_target_position(detection_3d.pointcloud, robot_pos)
+ if target_position is None:
+ logger.warning("Could not compute robust target position")
+ return None
+
+ return self._compute_twist_from_3d(target_position, robot_transform)
+
+ def _compute_robust_target_position(
+ self, pointcloud: PointCloud2, robot_pos: Vector3
+ ) -> Vector3 | None:
+ """Compute a robust target position from the detection pointcloud.
+
+ Instead of using the centroid of all points (which includes floor/background),
+ this method:
+ 1. Filters out floor points (z < 0.3m in world frame)
+ 2. Computes distance from robot to each remaining point
+ 3. Uses the 25th percentile of closest points to get the front surface
+ 4. Returns the centroid of those front-most points
+
+ Args:
+ pointcloud: The detection's pointcloud in world frame
+ robot_pos: Robot's current position in world frame
+
+ Returns:
+ Vector3 position representing the front of the detected object,
+ or None if not enough valid points.
+ """
+ points, _ = pointcloud.as_numpy()
+ if len(points) < 10:
+ return None
+
+ # Filter out floor points (keep points above 0.3m height)
+ height_mask = points[:, 2] > 0.3
+ points = points[height_mask]
+ if len(points) < 10:
+ # Fall back to all points if height filtering removes too many
+ points, _ = pointcloud.as_numpy()
+
+ # Compute 2D distance (XY plane) from robot to each point
+ dx = points[:, 0] - robot_pos.x
+ dy = points[:, 1] - robot_pos.y
+ distances = np.sqrt(dx * dx + dy * dy)
+
+ # Use 25th percentile of distances to find front-most points
+ distance_threshold = np.percentile(distances, 25)
+
+ # Get points that are within the front 25%
+ front_mask = distances <= distance_threshold
+ front_points = points[front_mask]
+
+ if len(front_points) < 3:
+ # Fall back to median distance point
+ median_dist = np.median(distances)
+ close_mask = np.abs(distances - median_dist) < 0.3
+ front_points = points[close_mask]
+ if len(front_points) < 3:
+ return None
+
+ # Compute centroid of front-most points
+ centroid = front_points.mean(axis=0)
+ return Vector3(centroid[0], centroid[1], centroid[2])
+
+ def _compute_twist_from_3d(self, target_position: Vector3, robot_transform: Transform) -> Twist:
+ """Compute twist command to navigate towards a 3D target position.
+
+ Args:
+ target_position: 3D position of the target in world frame.
+ robot_transform: Robot's current transform in world frame.
+
+ Returns:
+ Twist command for the robot.
+ """
+ robot_pos = robot_transform.translation
+
+ # Compute vector from robot to target in world frame
+ dx = target_position.x - robot_pos.x
+ dy = target_position.y - robot_pos.y
+ distance = np.sqrt(dx * dx + dy * dy)
+ print(f"Distance to target: {distance:.2f} m")
+
+ # Compute angle to target in world frame
+ angle_to_target = np.arctan2(dy, dx)
+
+ # Get robot's current heading from transform
+ robot_yaw = robot_transform.rotation.to_euler().z
+
+ # Angle error (how much to turn)
+ angle_error = angle_to_target - robot_yaw
+ # Normalize to [-pi, pi]
+ while angle_error > np.pi:
+ angle_error -= 2 * np.pi
+ while angle_error < -np.pi:
+ angle_error += 2 * np.pi
+
+ # Compute angular velocity (turn towards target)
+ angular_z = angle_error * self._angular_gain_3d
+ angular_z = float(
+ np.clip(angular_z, -self._max_angular_speed_3d, self._max_angular_speed_3d)
+ )
+
+ # Compute linear velocity based on distance
+ distance_error = distance - self._target_distance_3d
+
+ if distance < self._min_distance_3d:
+ # Too close, back up
+ linear_x = -self._max_linear_speed_3d * 0.6
+ else:
+ # Move forward based on distance error, reduce speed when turning
+ turn_factor = 1.0 - min(abs(angle_error) / np.pi, 0.7)
+ linear_x = distance_error * self._linear_gain_3d * turn_factor
+ linear_x = float(
+ np.clip(linear_x, -self._max_linear_speed_3d, self._max_linear_speed_3d)
+ )
+
+ return Twist(
+ linear=Vector3(linear_x, 0.0, 0.0),
+ angular=Vector3(0.0, 0.0, angular_z),
+ )
diff --git a/dimos/navigation/visual_servoing/visual_servoing_2d.py b/dimos/navigation/visual_servoing/visual_servoing_2d.py
new file mode 100644
index 0000000000..032b5f3370
--- /dev/null
+++ b/dimos/navigation/visual_servoing/visual_servoing_2d.py
@@ -0,0 +1,166 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from dimos.msgs.geometry_msgs import Twist, Vector3
+from dimos.msgs.sensor_msgs import CameraInfo
+
+
+class VisualServoing2D:
+ """2D visual servoing controller for tracking objects using bounding boxes.
+
+ Uses camera intrinsics to convert pixel coordinates to normalized camera
+ coordinates and estimates distance based on known object width.
+ """
+
+ # Target distance to maintain from object (meters).
+ _target_distance: float = 1.5
+
+ # Minimum distance before backing up (meters).
+ _min_distance: float = 0.8
+
+ # Maximum forward/backward speed (m/s).
+ _max_linear_speed: float = 0.5
+
+ # Maximum turning speed (rad/s).
+ _max_angular_speed: float = 0.8
+
+ # Assumed real-world width of tracked object (meters).
+ _assumed_object_width: float = 0.45
+
+ # Proportional gain for angular velocity control.
+ _angular_gain: float = 1.0
+
+ # Proportional gain for linear velocity control.
+ _linear_gain: float = 0.8
+
+ # Speed factor when backing up (multiplied by max_linear_speed).
+ _backup_speed_factor: float = 0.6
+
+ # Multiplier for x_norm when calculating turn factor.
+ _turn_factor_multiplier: float = 2.0
+
+ # Maximum speed reduction due to turning (turn_factor ranges from 1-this to 1).
+ _turn_factor_max_reduction: float = 0.7
+
+ _rotation_requires_linear_movement: bool = False
+
+ # Camera intrinsics for coordinate conversion.
+ _camera_info: CameraInfo
+
+ def __init__(
+ self, camera_info: CameraInfo, rotation_requires_linear_movement: bool = False
+ ) -> None:
+ self._camera_info = camera_info
+ self._rotation_requires_linear_movement = rotation_requires_linear_movement
+
+ def compute_twist(
+ self,
+ bbox: tuple[float, float, float, float],
+ image_width: int,
+ ) -> Twist:
+ """Compute twist command to servo towards the tracked object.
+
+ Args:
+ bbox: Bounding box (x1, y1, x2, y2) in pixels.
+ image_width: Width of the image.
+
+ Returns:
+ Twist command for the robot.
+ """
+ x1, _, x2, _ = bbox
+ bbox_center_x = (x1 + x2) / 2.0
+
+ # Get normalized x coordinate using inverse K matrix
+ # Positive = object is to the right of optical center
+ x_norm = self._get_normalized_x(bbox_center_x)
+
+ estimated_distance = self._estimate_distance(bbox)
+
+ if estimated_distance is None:
+ return Twist.zero()
+
+ # Calculate distance error (positive = too far, need to move forward)
+ distance_error = estimated_distance - self._target_distance
+
+ # Compute angular velocity (turn towards object)
+ # Negative because positive angular.z is counter-clockwise (left turn)
+ angular_z = -x_norm * self._angular_gain
+ angular_z = float(np.clip(angular_z, -self._max_angular_speed, self._max_angular_speed))
+
+ # Compute linear velocity - ALWAYS move forward/backward based on distance.
+ # Reduce forward speed when turning sharply to maintain stability.
+ turn_factor = 1.0 - min(
+ abs(x_norm) * self._turn_factor_multiplier, self._turn_factor_max_reduction
+ )
+
+ if estimated_distance < self._min_distance:
+ # Too close, back up (don't reduce speed for backing up)
+ linear_x = -self._max_linear_speed * self._backup_speed_factor
+ else:
+ # Move forward based on distance error with proportional gain
+ linear_x = distance_error * self._linear_gain * turn_factor
+ linear_x = float(np.clip(linear_x, -self._max_linear_speed, self._max_linear_speed))
+
+        # Enforce minimum forward speed when the angular command is near zero (NOTE(review): comment said "when turning" but the condition fires when |angular_z| < 0.02 — confirm the intended threshold direction)
+ if self._rotation_requires_linear_movement and abs(angular_z) < 0.02:
+ linear_x = max(linear_x, 0.1)
+
+ return Twist(
+ linear=Vector3(linear_x, 0.0, 0.0),
+ angular=Vector3(0.0, 0.0, angular_z),
+ )
+
+ def _get_normalized_x(self, pixel_x: float) -> float:
+ """Convert pixel x coordinate to normalized camera coordinate.
+
+ Uses inverse K matrix: x_norm = (pixel_x - cx) / fx
+
+ Args:
+ pixel_x: x coordinate in pixels
+
+ Returns:
+ Normalized x coordinate (tan of angle from optical center)
+ """
+ fx = self._camera_info.K[0] # focal length x
+ cx = self._camera_info.K[2] # optical center x
+ return (pixel_x - cx) / fx
+
+ def _estimate_distance(self, bbox: tuple[float, float, float, float]) -> float | None:
+ """Estimate distance to object based on bounding box size and camera intrinsics.
+
+ Uses the pinhole camera model:
+ pixel_width / fx = real_width / distance
+ distance = (real_width * fx) / pixel_width
+
+ Uses bbox width instead of height because ground robot can't see full
+ person height when close. Width (shoulders) is more consistently visible.
+
+ Args:
+ bbox: Bounding box (x1, y1, x2, y2) in pixels.
+
+ Returns:
+ Estimated distance in meters, or None if bbox is invalid.
+ """
+ bbox_width = bbox[2] - bbox[0] # x2 - x1
+
+ if bbox_width <= 0:
+ return None
+
+ # Pinhole camera model: distance = (real_width * fx) / pixel_width
+ fx = self._camera_info.K[0] # focal length x in pixels
+ estimated_distance = (self._assumed_object_width * fx) / bbox_width
+
+ return estimated_distance
diff --git a/dimos/perception/common/__init__.py b/dimos/perception/common/__init__.py
index 67481bc449..16281fe0b6 100644
--- a/dimos/perception/common/__init__.py
+++ b/dimos/perception/common/__init__.py
@@ -1,3 +1 @@
-from .detection2d_tracker import get_tracked_results, target2dTracker
-from .ibvs import *
from .utils import *
diff --git a/dimos/perception/common/detection2d_tracker.py b/dimos/perception/common/detection2d_tracker.py
deleted file mode 100644
index 9ff36be8a1..0000000000
--- a/dimos/perception/common/detection2d_tracker.py
+++ /dev/null
@@ -1,396 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import deque
-from collections.abc import Sequence
-
-import numpy as np
-
-
-def compute_iou(bbox1, bbox2): # type: ignore[no-untyped-def]
- """
- Compute Intersection over Union (IoU) of two bounding boxes.
- Each bbox is [x1, y1, x2, y2].
- """
- x1 = max(bbox1[0], bbox2[0])
- y1 = max(bbox1[1], bbox2[1])
- x2 = min(bbox1[2], bbox2[2])
- y2 = min(bbox1[3], bbox2[3])
-
- inter_area = max(0, x2 - x1) * max(0, y2 - y1)
- area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
- area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
-
- union_area = area1 + area2 - inter_area
- if union_area == 0:
- return 0
- return inter_area / union_area
-
-
-def get_tracked_results(tracked_targets): # type: ignore[no-untyped-def]
- """
- Extract tracked results from a list of target2d objects.
-
- Args:
- tracked_targets (list[target2d]): List of target2d objects (published targets)
- returned by the tracker's update() function.
-
- Returns:
- tuple: (tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names)
- where each is a list of the corresponding attribute from each target.
- """
- tracked_masks = []
- tracked_bboxes = []
- tracked_track_ids = []
- tracked_probs = []
- tracked_names = []
-
- for target in tracked_targets:
- # Extract the latest values stored in each target.
- tracked_masks.append(target.latest_mask)
- tracked_bboxes.append(target.latest_bbox)
- # Here we use the most recent detection's track ID.
- tracked_track_ids.append(target.target_id)
- # Use the latest probability from the history.
- tracked_probs.append(target.score)
- # Use the stored name (if any). If not available, you can use a default value.
- tracked_names.append(target.name)
-
- return tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names
-
-
-class target2d:
- """
- Represents a tracked 2D target.
- Stores the latest bounding box and mask along with a short history of track IDs,
- detection probabilities, and computed texture values.
- """
-
- def __init__( # type: ignore[no-untyped-def]
- self,
- initial_mask,
- initial_bbox,
- track_id,
- prob: float,
- name: str,
- texture_value,
- target_id,
- history_size: int = 10,
- ) -> None:
- """
- Args:
- initial_mask (torch.Tensor): Latest segmentation mask.
- initial_bbox (list): Bounding box in [x1, y1, x2, y2] format.
- track_id (int): Detection’s track ID (may be -1 if not provided).
- prob (float): Detection probability.
- name (str): Object class name.
- texture_value (float): Computed average texture value for this detection.
- target_id (int): Unique identifier assigned by the tracker.
- history_size (int): Maximum number of frames to keep in the history.
- """
- self.target_id = target_id
- self.latest_mask = initial_mask
- self.latest_bbox = initial_bbox
- self.name = name
- self.score = 1.0
-
- self.track_id = track_id
- self.probs_history = deque(maxlen=history_size) # type: ignore[var-annotated]
- self.texture_history = deque(maxlen=history_size) # type: ignore[var-annotated]
-
- self.frame_count = deque(maxlen=history_size) # type: ignore[var-annotated] # Total frames this target has been seen.
- self.missed_frames = 0 # Consecutive frames when no detection was assigned.
- self.history_size = history_size
-
- def update(self, mask, bbox, track_id, prob: float, name: str, texture_value) -> None: # type: ignore[no-untyped-def]
- """
- Update the target with a new detection.
- """
- self.latest_mask = mask
- self.latest_bbox = bbox
- self.name = name
-
- self.track_id = track_id
- self.probs_history.append(prob)
- self.texture_history.append(texture_value)
-
- self.frame_count.append(1)
- self.missed_frames = 0
-
- def mark_missed(self) -> None:
- """
- Increment the count of consecutive frames where this target was not updated.
- """
- self.missed_frames += 1
- self.frame_count.append(0)
-
- def compute_score( # type: ignore[no-untyped-def]
- self,
- frame_shape,
- min_area_ratio,
- max_area_ratio,
- texture_range=(0.0, 1.0),
- border_safe_distance: int = 50,
- weights=None,
- ):
- """
- Compute a combined score for the target based on several factors.
-
- Factors:
- - **Detection probability:** Average over recent frames.
- - **Temporal stability:** How consistently the target has appeared.
- - **Texture quality:** Normalized using the provided min and max values.
- - **Border proximity:** Computed from the minimum distance from the bbox to the frame edges.
- - **Size:** How the object's area (relative to the frame) compares to acceptable bounds.
-
- Args:
- frame_shape (tuple): (height, width) of the frame.
- min_area_ratio (float): Minimum acceptable ratio (bbox area / frame area).
- max_area_ratio (float): Maximum acceptable ratio.
- texture_range (tuple): (min_texture, max_texture) expected values.
- border_safe_distance (float): Distance (in pixels) considered safe from the border.
- weights (dict): Weights for each component. Expected keys:
- 'prob', 'temporal', 'texture', 'border', and 'size'.
-
- Returns:
- float: The combined (normalized) score in the range [0, 1].
- """
- # Default weights if none provided.
- if weights is None:
- weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0}
-
- h, w = frame_shape
- x1, y1, x2, y2 = self.latest_bbox
- bbox_area = (x2 - x1) * (y2 - y1)
- frame_area = w * h
- area_ratio = bbox_area / frame_area
-
- # Detection probability factor.
- avg_prob = np.mean(self.probs_history)
- # Temporal stability factor: normalized by history size.
- temporal_stability = np.mean(self.frame_count)
- # Texture factor: normalize average texture using the provided range.
- avg_texture = np.mean(self.texture_history) if self.texture_history else 0.0
- min_texture, max_texture = texture_range
- if max_texture == min_texture:
- normalized_texture = avg_texture
- else:
- normalized_texture = (avg_texture - min_texture) / (max_texture - min_texture)
- normalized_texture = max(0.0, min(normalized_texture, 1.0))
-
- # Border factor: compute the minimum distance from the bbox to any frame edge.
- left_dist = x1
- top_dist = y1
- right_dist = w - x2
- min_border_dist = min(left_dist, top_dist, right_dist)
- # Normalize the border distance: full score (1.0) if at least border_safe_distance away.
- border_factor = min(1.0, min_border_dist / border_safe_distance)
-
- # Size factor: penalize objects that are too small or too big.
- if area_ratio < min_area_ratio:
- size_factor = area_ratio / min_area_ratio
- elif area_ratio > max_area_ratio:
- # Here we compute a linear penalty if the area exceeds max_area_ratio.
- if 1 - max_area_ratio > 0:
- size_factor = max(0, (1 - area_ratio) / (1 - max_area_ratio))
- else:
- size_factor = 0.0
- else:
- size_factor = 1.0
-
- # Combine factors using a weighted sum (each factor is assumed in [0, 1]).
- w_prob = weights.get("prob", 1.0)
- w_temporal = weights.get("temporal", 1.0)
- w_texture = weights.get("texture", 1.0)
- w_border = weights.get("border", 1.0)
- w_size = weights.get("size", 1.0)
- total_weight = w_prob + w_temporal + w_texture + w_border + w_size
-
- # print(f"track_id: {self.target_id}, avg_prob: {avg_prob:.2f}, temporal_stability: {temporal_stability:.2f}, normalized_texture: {normalized_texture:.2f}, border_factor: {border_factor:.2f}, size_factor: {size_factor:.2f}")
-
- final_score = (
- w_prob * avg_prob
- + w_temporal * temporal_stability
- + w_texture * normalized_texture
- + w_border * border_factor
- + w_size * size_factor
- ) / total_weight
-
- self.score = final_score
-
- return final_score
-
-
-class target2dTracker:
- """
- Tracker that maintains a history of targets across frames.
- New segmentation detections (frame, masks, bboxes, track_ids, probabilities,
- and computed texture values) are matched to existing targets or used to create new ones.
-
- The tracker uses a scoring system that incorporates:
- - **Detection probability**
- - **Temporal stability**
- - **Texture quality** (normalized within a specified range)
- - **Proximity to image borders** (a continuous penalty based on the distance)
- - **Object size** relative to the frame
-
- Targets are published if their score exceeds the start threshold and are removed if their score
- falls below the stop threshold or if they are missed for too many consecutive frames.
- """
-
- def __init__( # type: ignore[no-untyped-def]
- self,
- history_size: int = 10,
- score_threshold_start: float = 0.5,
- score_threshold_stop: float = 0.3,
- min_frame_count: int = 10,
- max_missed_frames: int = 3,
- min_area_ratio: float = 0.001,
- max_area_ratio: float = 0.1,
- texture_range=(0.0, 1.0),
- border_safe_distance: int = 50,
- weights=None,
- ) -> None:
- """
- Args:
- history_size (int): Maximum history length (number of frames) per target.
- score_threshold_start (float): Minimum score for a target to be published.
- score_threshold_stop (float): If a target’s score falls below this, it is removed.
- min_frame_count (int): Minimum number of frames a target must be seen to be published.
- max_missed_frames (int): Maximum consecutive frames a target can be missing before deletion.
- min_area_ratio (float): Minimum acceptable bbox area relative to the frame.
- max_area_ratio (float): Maximum acceptable bbox area relative to the frame.
- texture_range (tuple): (min_texture, max_texture) expected values.
- border_safe_distance (float): Distance (in pixels) considered safe from the border.
- weights (dict): Weights for the scoring components (keys: 'prob', 'temporal',
- 'texture', 'border', 'size').
- """
- self.history_size = history_size
- self.score_threshold_start = score_threshold_start
- self.score_threshold_stop = score_threshold_stop
- self.min_frame_count = min_frame_count
- self.max_missed_frames = max_missed_frames
- self.min_area_ratio = min_area_ratio
- self.max_area_ratio = max_area_ratio
- self.texture_range = texture_range
- self.border_safe_distance = border_safe_distance
- # Default weights if none are provided.
- if weights is None:
- weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0}
- self.weights = weights
-
- self.targets = {} # type: ignore[var-annotated] # Dictionary mapping target_id -> target2d instance.
- self.next_target_id = 0
-
- def update( # type: ignore[no-untyped-def]
- self,
- frame,
- masks,
- bboxes,
- track_ids,
- probs: Sequence[float],
- names: Sequence[str],
- texture_values,
- ):
- """
- Update the tracker with new detections from the current frame.
-
- Args:
- frame (np.ndarray): Current BGR frame.
- masks (list[torch.Tensor]): List of segmentation masks.
- bboxes (list): List of bounding boxes [x1, y1, x2, y2].
- track_ids (list): List of detection track IDs.
- probs (list): List of detection probabilities.
- names (list): List of class names.
- texture_values (list): List of computed texture values.
-
- Returns:
- published_targets (list[target2d]): Targets that are active and have scores above
- the start threshold.
- """
- updated_target_ids = set()
- frame_shape = frame.shape[:2] # (height, width)
-
- # For each detection, try to match with an existing target.
- for mask, bbox, det_tid, prob, name, texture in zip(
- masks, bboxes, track_ids, probs, names, texture_values, strict=False
- ):
- matched_target = None
-
- # First, try matching by detection track ID if valid.
- if det_tid != -1:
- for target in self.targets.values():
- if target.track_id == det_tid:
- matched_target = target
- break
-
- # Otherwise, try matching using IoU.
- if matched_target is None:
- best_iou = 0
- for target in self.targets.values():
- iou = compute_iou(bbox, target.latest_bbox) # type: ignore[no-untyped-call]
- if iou > 0.5 and iou > best_iou:
- best_iou = iou
- matched_target = target
-
- # Update existing target or create a new one.
- if matched_target is not None:
- matched_target.update(mask, bbox, det_tid, prob, name, texture)
- updated_target_ids.add(matched_target.target_id)
- else:
- new_target = target2d(
- mask, bbox, det_tid, prob, name, texture, self.next_target_id, self.history_size
- )
- self.targets[self.next_target_id] = new_target
- updated_target_ids.add(self.next_target_id)
- self.next_target_id += 1
-
- # Mark targets that were not updated.
- for target_id, target in list(self.targets.items()):
- if target_id not in updated_target_ids:
- target.mark_missed()
- if target.missed_frames > self.max_missed_frames:
- del self.targets[target_id]
- continue # Skip further checks for this target.
- # Remove targets whose score falls below the stop threshold.
- score = target.compute_score(
- frame_shape,
- self.min_area_ratio,
- self.max_area_ratio,
- texture_range=self.texture_range,
- border_safe_distance=self.border_safe_distance,
- weights=self.weights,
- )
- if score < self.score_threshold_stop:
- del self.targets[target_id]
-
- # Publish targets with scores above the start threshold.
- published_targets = []
- for target in self.targets.values():
- score = target.compute_score(
- frame_shape,
- self.min_area_ratio,
- self.max_area_ratio,
- texture_range=self.texture_range,
- border_safe_distance=self.border_safe_distance,
- weights=self.weights,
- )
- if (
- score >= self.score_threshold_start
- and sum(target.frame_count) >= self.min_frame_count
- and target.missed_frames <= 5
- ):
- published_targets.append(target)
-
- return published_targets
diff --git a/dimos/perception/common/export_tensorrt.py b/dimos/perception/common/export_tensorrt.py
deleted file mode 100644
index ca671e36f2..0000000000
--- a/dimos/perception/common/export_tensorrt.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-
-from ultralytics import YOLO, FastSAM # type: ignore[attr-defined, import-not-found]
-
-
-def parse_args(): # type: ignore[no-untyped-def]
- parser = argparse.ArgumentParser(description="Export YOLO/FastSAM models to different formats")
- parser.add_argument("--model_path", type=str, required=True, help="Path to the model weights")
- parser.add_argument(
- "--model_type",
- type=str,
- choices=["yolo", "fastsam"],
- required=True,
- help="Type of model to export",
- )
- parser.add_argument(
- "--precision",
- type=str,
- choices=["fp32", "fp16", "int8"],
- default="fp32",
- help="Precision for export",
- )
- parser.add_argument(
- "--format", type=str, choices=["onnx", "engine"], default="onnx", help="Export format"
- )
- return parser.parse_args()
-
-
-def main() -> None:
- args = parse_args() # type: ignore[no-untyped-call]
- half = args.precision == "fp16"
- int8 = args.precision == "int8"
- # Load the appropriate model
- if args.model_type == "yolo":
- model: YOLO | FastSAM = YOLO(args.model_path)
- else:
- model = FastSAM(args.model_path)
-
- # Export the model
- model.export(format=args.format, half=half, int8=int8)
-
-
-if __name__ == "__main__":
- main()
diff --git a/dimos/perception/common/ibvs.py b/dimos/perception/common/ibvs.py
deleted file mode 100644
index e24819f432..0000000000
--- a/dimos/perception/common/ibvs.py
+++ /dev/null
@@ -1,280 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-
-
-class PersonDistanceEstimator:
- def __init__(self, K, camera_pitch, camera_height) -> None: # type: ignore[no-untyped-def]
- """
- Initialize the distance estimator using ground plane constraint.
-
- Args:
- K: 3x3 Camera intrinsic matrix in OpenCV format
- (Assumed to be already for an undistorted image)
- camera_pitch: Upward pitch of the camera (in radians), in the robot frame
- Positive means looking up, negative means looking down
- camera_height: Height of the camera above the ground (in meters)
- """
- self.K = K
- self.camera_height = camera_height
-
- # Precompute the inverse intrinsic matrix
- self.K_inv = np.linalg.inv(K)
-
- # Transform from camera to robot frame (z-forward to x-forward)
- self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
-
- # Pitch rotation matrix (positive is upward)
- theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y
- self.R_pitch = np.array(
- [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]
- )
-
- # Combined transform from camera to robot frame
- self.A = self.R_pitch @ self.T
-
- # Store focal length and principal point for angle calculation
- self.fx = K[0, 0]
- self.cx = K[0, 2]
-
- def estimate_distance_angle(self, bbox: tuple, robot_pitch: float | None = None): # type: ignore[no-untyped-def, type-arg]
- """
- Estimate distance and angle to person using ground plane constraint.
-
- Args:
- bbox: tuple (x_min, y_min, x_max, y_max)
- where y_max represents the feet position
- robot_pitch: Current pitch of the robot body (in radians)
- If provided, this will be combined with the camera's fixed pitch
-
- Returns:
- depth: distance to person along camera's z-axis (meters)
- angle: horizontal angle in camera frame (radians, positive right)
- """
- x_min, _, x_max, y_max = bbox
-
- # Get center point of feet
- u_c = (x_min + x_max) / 2.0
- v_feet = y_max
-
- # Create homogeneous feet point and get ray direction
- p_feet = np.array([u_c, v_feet, 1.0])
- d_feet_cam = self.K_inv @ p_feet
-
- # If robot_pitch is provided, recalculate the transformation matrix
- if robot_pitch is not None:
- # Combined pitch (fixed camera pitch + current robot pitch)
- total_pitch = -camera_pitch - robot_pitch # Both negated for correct rotation direction
- R_total_pitch = np.array(
- [
- [np.cos(total_pitch), 0, np.sin(total_pitch)],
- [0, 1, 0],
- [-np.sin(total_pitch), 0, np.cos(total_pitch)],
- ]
- )
- # Use the updated transformation matrix
- A = R_total_pitch @ self.T
- else:
- # Use the precomputed transformation matrix
- A = self.A
-
- # Convert ray to robot frame using appropriate transformation
- d_feet_robot = A @ d_feet_cam
-
- # Ground plane intersection (z=0)
- # camera_height + t * d_feet_robot[2] = 0
- if abs(d_feet_robot[2]) < 1e-6:
- raise ValueError("Feet ray is parallel to ground plane")
-
- # Solve for scaling factor t
- t = -self.camera_height / d_feet_robot[2]
-
- # Get 3D feet position in robot frame
- p_feet_robot = t * d_feet_robot
-
- # Convert back to camera frame
- p_feet_cam = self.A.T @ p_feet_robot
-
- # Extract depth (z-coordinate in camera frame)
- depth = p_feet_cam[2]
-
- # Calculate horizontal angle from image center
- angle = np.arctan((u_c - self.cx) / self.fx)
-
- return depth, angle
-
-
-class ObjectDistanceEstimator:
- """
- Estimate distance to an object using the ground plane constraint.
- This class assumes the camera is mounted on a robot and uses the
- camera's intrinsic parameters to estimate the distance to a detected object.
- """
-
- def __init__(self, K, camera_pitch, camera_height) -> None: # type: ignore[no-untyped-def]
- """
- Initialize the distance estimator using ground plane constraint.
-
- Args:
- K: 3x3 Camera intrinsic matrix in OpenCV format
- (Assumed to be already for an undistorted image)
- camera_pitch: Upward pitch of the camera (in radians)
- Positive means looking up, negative means looking down
- camera_height: Height of the camera above the ground (in meters)
- """
- self.K = K
- self.camera_height = camera_height
-
- # Precompute the inverse intrinsic matrix
- self.K_inv = np.linalg.inv(K)
-
- # Transform from camera to robot frame (z-forward to x-forward)
- self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
-
- # Pitch rotation matrix (positive is upward)
- theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y
- self.R_pitch = np.array(
- [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]
- )
-
- # Combined transform from camera to robot frame
- self.A = self.R_pitch @ self.T
-
- # Store focal length and principal point for angle calculation
- self.fx = K[0, 0]
- self.fy = K[1, 1]
- self.cx = K[0, 2]
- self.estimated_object_size = None
-
- def estimate_object_size(self, bbox: tuple, distance: float): # type: ignore[no-untyped-def, type-arg]
- """
- Estimate the physical size of an object based on its bbox and known distance.
-
- Args:
- bbox: tuple (x_min, y_min, x_max, y_max) bounding box in the image
- distance: Known distance to the object (in meters)
- robot_pitch: Current pitch of the robot body (in radians), if any
-
- Returns:
- estimated_size: Estimated physical height of the object (in meters)
- """
- _x_min, y_min, _x_max, y_max = bbox
-
- # Calculate object height in pixels
- object_height_px = y_max - y_min
-
- # Calculate the physical height using the known distance and focal length
- estimated_size = object_height_px * distance / self.fy
- self.estimated_object_size = estimated_size
-
- return estimated_size
-
- def set_estimated_object_size(self, size: float) -> None:
- """
- Set the estimated object size for future distance calculations.
-
- Args:
- size: Estimated physical size of the object (in meters)
- """
- self.estimated_object_size = size # type: ignore[assignment]
-
- def estimate_distance_angle(self, bbox: tuple): # type: ignore[no-untyped-def, type-arg]
- """
- Estimate distance and angle to object using size-based estimation.
-
- Args:
- bbox: tuple (x_min, y_min, x_max, y_max)
- where y_max represents the bottom of the object
- robot_pitch: Current pitch of the robot body (in radians)
- If provided, this will be combined with the camera's fixed pitch
- initial_distance: Initial distance estimate for the object (in meters)
- Used to calibrate object size if not previously known
-
- Returns:
- depth: distance to object along camera's z-axis (meters)
- angle: horizontal angle in camera frame (radians, positive right)
- or None, None if estimation not possible
- """
- # If we don't have estimated object size and no initial distance is provided,
- # we can't estimate the distance
- if self.estimated_object_size is None:
- return None, None
-
- x_min, y_min, x_max, y_max = bbox
-
- # Calculate center of the object for angle calculation
- u_c = (x_min + x_max) / 2.0
-
- # If we have an initial distance estimate and no object size yet,
- # calculate and store the object size using the initial distance
- object_height_px = y_max - y_min
- depth = self.estimated_object_size * self.fy / object_height_px
-
- # Calculate horizontal angle from image center
- angle = np.arctan((u_c - self.cx) / self.fx)
-
- return depth, angle
-
-
-# Example usage:
-if __name__ == "__main__":
- # Example camera calibration
- K = np.array([[600, 0, 320], [0, 600, 240], [0, 0, 1]], dtype=np.float32)
-
- # Camera mounted 1.2m high, pitched down 10 degrees
- camera_pitch = np.deg2rad(0) # negative for downward pitch
- camera_height = 1.0 # meters
-
- estimator = PersonDistanceEstimator(K, camera_pitch, camera_height)
- object_estimator = ObjectDistanceEstimator(K, camera_pitch, camera_height)
-
- # Example detection
- bbox = (300, 100, 380, 400) # x1, y1, x2, y2
-
- depth, angle = estimator.estimate_distance_angle(bbox)
- # Estimate object size based on the known distance
- object_size = object_estimator.estimate_object_size(bbox, depth)
- depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox)
-
- print(f"Estimated person depth: {depth:.2f} m")
- print(f"Estimated person angle: {np.rad2deg(angle):.1f}°")
- print(f"Estimated object depth: {depth_obj:.2f} m")
- print(f"Estimated object angle: {np.rad2deg(angle_obj):.1f}°")
-
- # Shrink the bbox by 30 pixels while keeping the same center
- x_min, y_min, x_max, y_max = bbox
- width = x_max - x_min
- height = y_max - y_min
- center_x = (x_min + x_max) // 2
- center_y = (y_min + y_max) // 2
-
- new_width = max(width - 20, 2) # Ensure width is at least 2 pixels
- new_height = max(height - 20, 2) # Ensure height is at least 2 pixels
-
- x_min = center_x - new_width // 2
- x_max = center_x + new_width // 2
- y_min = center_y - new_height // 2
- y_max = center_y + new_height // 2
-
- bbox = (x_min, y_min, x_max, y_max)
-
- # Re-estimate distance and angle with the new bbox
- depth, angle = estimator.estimate_distance_angle(bbox)
- depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox)
-
- print(f"New estimated person depth: {depth:.2f} m")
- print(f"New estimated person angle: {np.rad2deg(angle):.1f}°")
- print(f"New estimated object depth: {depth_obj:.2f} m")
- print(f"New estimated object angle: {np.rad2deg(angle_obj):.1f}°")
diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py
new file mode 100644
index 0000000000..d1d879d0ab
--- /dev/null
+++ b/dimos/perception/demo_object_scene_registration.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dimos.agents.agent import llm_agent
+from dimos.agents.cli.human import human_input
+from dimos.core.blueprints import autoconnect
+from dimos.hardware.sensors.camera.realsense import realsense_camera
+from dimos.hardware.sensors.camera.zed import zed_camera
+from dimos.perception.detection.detectors.yoloe import YoloePromptMode
+from dimos.perception.object_scene_registration import object_scene_registration_module
+from dimos.robot.foxglove_bridge import foxglove_bridge
+
+camera_choice = "zed"
+
+if camera_choice == "realsense":
+ camera_module = realsense_camera(enable_pointcloud=False)
+elif camera_choice == "zed":
+ camera_module = zed_camera(enable_pointcloud=False)
+else:
+ raise ValueError(f"Invalid camera choice: {camera_choice}")
+
+demo_object_scene_registration = autoconnect(
+ camera_module,
+ object_scene_registration_module(target_frame="world", prompt_mode=YoloePromptMode.LRPC),
+ foxglove_bridge(),
+ human_input(),
+ llm_agent(),
+).global_config(viewer_backend="foxglove")
diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py
index 1c9c8ca05c..8c6953e410 100644
--- a/dimos/perception/detection/conftest.py
+++ b/dimos/perception/detection/conftest.py
@@ -36,7 +36,6 @@
)
from dimos.protocol.tf import TF
from dimos.robot.unitree.connection import go2
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
from dimos.robot.unitree_webrtc.type.odometry import Odometry
from dimos.utils.data import get_data
from dimos.utils.testing import TimedSensorReplay
@@ -44,7 +43,7 @@
class Moment(TypedDict, total=False):
odom_frame: Odometry
- lidar_frame: LidarMessage
+ lidar_frame: PointCloud2
image_frame: Image
camera_info: CameraInfo
transforms: list[Transform]
@@ -83,7 +82,7 @@ def moment_provider(**kwargs) -> Moment:
lidar_frame_result = TimedSensorReplay(f"{data_dir}/lidar").find_closest_seek(seek)
if lidar_frame_result is None:
raise ValueError("No lidar frame found")
- lidar_frame: LidarMessage = lidar_frame_result
+ lidar_frame: PointCloud2 = lidar_frame_result
image_frame = TimedSensorReplay(
f"{data_dir}/video",
diff --git a/dimos/perception/detection/detectors/conftest.py b/dimos/perception/detection/detectors/conftest.py
index 9cb600aeff..6a2c041a8b 100644
--- a/dimos/perception/detection/detectors/conftest.py
+++ b/dimos/perception/detection/detectors/conftest.py
@@ -17,6 +17,7 @@
from dimos.msgs.sensor_msgs import Image
from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector
from dimos.perception.detection.detectors.yolo import Yolo2DDetector
+from dimos.perception.detection.detectors.yoloe import Yoloe2DDetector, YoloePromptMode
from dimos.utils.data import get_data
@@ -36,3 +37,9 @@ def person_detector():
def bbox_detector():
"""Create a Yolo2DDetector instance for general object detection."""
return Yolo2DDetector()
+
+
+@pytest.fixture(scope="session")
+def yoloe_detector():
+ """Create a Yoloe2DDetector instance for general object detection."""
+ return Yoloe2DDetector(prompt_mode=YoloePromptMode.LRPC)
diff --git a/dimos/perception/detection/detectors/detic.py b/dimos/perception/detection/detectors/detic.py
deleted file mode 100644
index 288a3e056d..0000000000
--- a/dimos/perception/detection/detectors/detic.py
+++ /dev/null
@@ -1,426 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections.abc import Sequence
-import os
-import sys
-
-import numpy as np
-
-# Add Detic to Python path
-from dimos.constants import DIMOS_PROJECT_ROOT
-from dimos.msgs.sensor_msgs import Image
-from dimos.perception.detection.detectors.types import Detector
-from dimos.perception.detection2d.utils import plot_results
-
-detic_path = DIMOS_PROJECT_ROOT / "dimos/models/Detic"
-if str(detic_path) not in sys.path:
- sys.path.append(str(detic_path))
- sys.path.append(str(detic_path / "third_party/CenterNet2"))
-
-# PIL patch for compatibility
-import PIL.Image
-
-if not hasattr(PIL.Image, "LINEAR") and hasattr(PIL.Image, "BILINEAR"):
- PIL.Image.LINEAR = PIL.Image.BILINEAR # type: ignore[attr-defined]
-
-# Detectron2 imports
-from detectron2.config import get_cfg # type: ignore[import-not-found]
-from detectron2.data import MetadataCatalog # type: ignore[import-not-found]
-
-
-# Simple tracking implementation
-class SimpleTracker:
- """Simple IOU-based tracker implementation without external dependencies"""
-
- def __init__(self, iou_threshold: float = 0.3, max_age: int = 5) -> None:
- self.iou_threshold = iou_threshold
- self.max_age = max_age
- self.next_id = 1
- self.tracks = {} # type: ignore[var-annotated] # id -> {bbox, class_id, age, mask, etc}
-
- def _calculate_iou(self, bbox1, bbox2): # type: ignore[no-untyped-def]
- """Calculate IoU between two bboxes in format [x1,y1,x2,y2]"""
- x1 = max(bbox1[0], bbox2[0])
- y1 = max(bbox1[1], bbox2[1])
- x2 = min(bbox1[2], bbox2[2])
- y2 = min(bbox1[3], bbox2[3])
-
- if x2 < x1 or y2 < y1:
- return 0.0
-
- intersection = (x2 - x1) * (y2 - y1)
- area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
- area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
- union = area1 + area2 - intersection
-
- return intersection / union if union > 0 else 0
-
- def update(self, detections, masks): # type: ignore[no-untyped-def]
- """Update tracker with new detections
-
- Args:
- detections: List of [x1,y1,x2,y2,score,class_id]
- masks: List of segmentation masks corresponding to detections
-
- Returns:
- List of [track_id, bbox, score, class_id, mask]
- """
- if len(detections) == 0:
- # Age existing tracks
- for track_id in list(self.tracks.keys()):
- self.tracks[track_id]["age"] += 1
- # Remove old tracks
- if self.tracks[track_id]["age"] > self.max_age:
- del self.tracks[track_id]
- return []
-
- # Convert to numpy for easier handling
- if not isinstance(detections, np.ndarray):
- detections = np.array(detections)
-
- result = []
- matched_indices = set()
-
- # Update existing tracks
- for track_id, track in list(self.tracks.items()):
- track["age"] += 1
-
- if track["age"] > self.max_age:
- del self.tracks[track_id]
- continue
-
- # Find best matching detection for this track
- best_iou = self.iou_threshold
- best_idx = -1
-
- for i, det in enumerate(detections):
- if i in matched_indices:
- continue
-
- # Check class match
- if det[5] != track["class_id"]:
- continue
-
- iou = self._calculate_iou(track["bbox"], det[:4]) # type: ignore[no-untyped-call]
- if iou > best_iou:
- best_iou = iou
- best_idx = i
-
- # If we found a match, update the track
- if best_idx >= 0:
- self.tracks[track_id]["bbox"] = detections[best_idx][:4]
- self.tracks[track_id]["score"] = detections[best_idx][4]
- self.tracks[track_id]["age"] = 0
- self.tracks[track_id]["mask"] = masks[best_idx]
- matched_indices.add(best_idx)
-
- # Add to results with mask
- result.append(
- [
- track_id,
- detections[best_idx][:4],
- detections[best_idx][4],
- int(detections[best_idx][5]),
- self.tracks[track_id]["mask"],
- ]
- )
-
- # Create new tracks for unmatched detections
- for i, det in enumerate(detections):
- if i in matched_indices:
- continue
-
- # Create new track
- new_id = self.next_id
- self.next_id += 1
-
- self.tracks[new_id] = {
- "bbox": det[:4],
- "score": det[4],
- "class_id": int(det[5]),
- "age": 0,
- "mask": masks[i],
- }
-
- # Add to results with mask directly from the track
- result.append([new_id, det[:4], det[4], int(det[5]), masks[i]])
-
- return result
-
-
-class Detic2DDetector(Detector):
- def __init__( # type: ignore[no-untyped-def]
- self, model_path=None, device: str = "cuda", vocabulary=None, threshold: float = 0.5
- ) -> None:
- """
- Initialize the Detic detector with open vocabulary support.
-
- Args:
- model_path (str): Path to a custom Detic model weights (optional)
- device (str): Device to run inference on ('cuda' or 'cpu')
- vocabulary (list): Custom vocabulary (list of class names) or 'lvis', 'objects365', 'openimages', 'coco'
- threshold (float): Detection confidence threshold
- """
- self.device = device
- self.threshold = threshold
-
- # Set up Detic paths - already added to sys.path at module level
-
- # Import Detic modules
- from centernet.config import add_centernet_config # type: ignore[import-not-found]
- from detic.config import add_detic_config # type: ignore[import-not-found]
- from detic.modeling.text.text_encoder import ( # type: ignore[import-not-found]
- build_text_encoder,
- )
- from detic.modeling.utils import reset_cls_test # type: ignore[import-not-found]
-
- # Keep reference to these functions for later use
- self.reset_cls_test = reset_cls_test
- self.build_text_encoder = build_text_encoder
-
- # Setup model configuration
- self.cfg = get_cfg()
- add_centernet_config(self.cfg)
- add_detic_config(self.cfg)
-
- # Use default Detic config
- self.cfg.merge_from_file(
- os.path.join(
- detic_path, "configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml"
- )
- )
-
- # Set default weights if not provided
- if model_path is None:
- self.cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth"
- else:
- self.cfg.MODEL.WEIGHTS = model_path
-
- # Set device
- if device == "cpu":
- self.cfg.MODEL.DEVICE = "cpu"
-
- # Set detection threshold
- self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
- self.cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand"
- self.cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True
-
- # Built-in datasets for Detic - use absolute paths with detic_path
- self.builtin_datasets = {
- "lvis": {
- "metadata": "lvis_v1_val",
- "classifier": os.path.join(
- detic_path, "datasets/metadata/lvis_v1_clip_a+cname.npy"
- ),
- },
- "objects365": {
- "metadata": "objects365_v2_val",
- "classifier": os.path.join(
- detic_path, "datasets/metadata/o365_clip_a+cnamefix.npy"
- ),
- },
- "openimages": {
- "metadata": "oid_val_expanded",
- "classifier": os.path.join(detic_path, "datasets/metadata/oid_clip_a+cname.npy"),
- },
- "coco": {
- "metadata": "coco_2017_val",
- "classifier": os.path.join(detic_path, "datasets/metadata/coco_clip_a+cname.npy"),
- },
- }
-
- # Override config paths to use absolute paths
- self.cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = os.path.join(
- detic_path, "datasets/metadata/lvis_v1_train_cat_info.json"
- )
-
- # Initialize model
- self.predictor = None
-
- # Setup with initial vocabulary
- vocabulary = vocabulary or "lvis"
- self.setup_vocabulary(vocabulary) # type: ignore[no-untyped-call]
-
- # Initialize our simple tracker
- self.tracker = SimpleTracker(iou_threshold=0.5, max_age=5)
-
- def setup_vocabulary(self, vocabulary): # type: ignore[no-untyped-def]
- """
- Setup the model's vocabulary.
-
- Args:
- vocabulary: Either a string ('lvis', 'objects365', 'openimages', 'coco')
- or a list of class names for custom vocabulary.
- """
- if self.predictor is None:
- # Initialize the model
- from detectron2.engine import DefaultPredictor # type: ignore[import-not-found]
-
- self.predictor = DefaultPredictor(self.cfg)
-
- if isinstance(vocabulary, str) and vocabulary in self.builtin_datasets:
- # Use built-in dataset
- dataset = vocabulary
- metadata = MetadataCatalog.get(self.builtin_datasets[dataset]["metadata"])
- classifier = self.builtin_datasets[dataset]["classifier"]
- num_classes = len(metadata.thing_classes)
- self.class_names = metadata.thing_classes
- else:
- # Use custom vocabulary
- if isinstance(vocabulary, str):
- # If it's a string but not a built-in dataset, treat as a file
- try:
- with open(vocabulary) as f:
- class_names = [line.strip() for line in f if line.strip()]
- except:
- # Default to LVIS if there's an issue
- print(f"Error loading vocabulary from {vocabulary}, using LVIS")
- return self.setup_vocabulary("lvis") # type: ignore[no-untyped-call]
- else:
- # Assume it's a list of class names
- class_names = vocabulary
-
- # Create classifier from text embeddings
- metadata = MetadataCatalog.get("__unused")
- metadata.thing_classes = class_names
- self.class_names = class_names
-
- # Generate CLIP embeddings for custom vocabulary
- classifier = self._get_clip_embeddings(class_names)
- num_classes = len(class_names)
-
- # Reset model with new vocabulary
- self.reset_cls_test(self.predictor.model, classifier, num_classes) # type: ignore[attr-defined]
- return self.class_names
-
- def _get_clip_embeddings(self, vocabulary, prompt: str = "a "): # type: ignore[no-untyped-def]
- """
- Generate CLIP embeddings for a vocabulary list.
-
- Args:
- vocabulary (list): List of class names
- prompt (str): Prompt prefix to use for CLIP
-
- Returns:
- torch.Tensor: Tensor of embeddings
- """
- text_encoder = self.build_text_encoder(pretrain=True)
- text_encoder.eval()
- texts = [prompt + x for x in vocabulary]
- emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
- return emb
-
- def process_image(self, image: Image): # type: ignore[no-untyped-def]
- """
- Process an image and return detection results.
-
- Args:
- image: Input image in BGR format (OpenCV)
-
- Returns:
- tuple: (bboxes, track_ids, class_ids, confidences, names, masks)
- - bboxes: list of [x1, y1, x2, y2] coordinates
- - track_ids: list of tracking IDs (or -1 if no tracking)
- - class_ids: list of class indices
- - confidences: list of detection confidences
- - names: list of class names
- - masks: list of segmentation masks (numpy arrays)
- """
- # Run inference with Detic
- outputs = self.predictor(image.to_opencv()) # type: ignore[misc]
- instances = outputs["instances"].to("cpu")
-
- # Extract bounding boxes, classes, scores, and masks
- if len(instances) == 0:
- return [], [], [], [], [] # , []
-
- boxes = instances.pred_boxes.tensor.numpy()
- class_ids = instances.pred_classes.numpy()
- scores = instances.scores.numpy()
- masks = instances.pred_masks.numpy()
-
- # Convert boxes to [x1, y1, x2, y2] format
- bboxes = []
- for box in boxes:
- x1, y1, x2, y2 = box.tolist()
- bboxes.append([x1, y1, x2, y2])
-
- # Get class names
- [self.class_names[class_id] for class_id in class_ids]
-
- # Apply tracking
- detections = []
- filtered_masks = []
- for i, bbox in enumerate(bboxes):
- if scores[i] >= self.threshold:
- # Format for tracker: [x1, y1, x2, y2, score, class_id]
- detections.append([*bbox, scores[i], class_ids[i]])
- filtered_masks.append(masks[i])
-
- if not detections:
- return [], [], [], [], [] # , []
-
- # Update tracker with detections and correctly aligned masks
- track_results = self.tracker.update(detections, filtered_masks) # type: ignore[no-untyped-call]
-
- # Process tracking results
- track_ids = []
- tracked_bboxes = []
- tracked_class_ids = []
- tracked_scores = []
- tracked_names = []
- tracked_masks = []
-
- for track_id, bbox, score, class_id, mask in track_results:
- track_ids.append(int(track_id))
- tracked_bboxes.append(bbox.tolist() if isinstance(bbox, np.ndarray) else bbox)
- tracked_class_ids.append(int(class_id))
- tracked_scores.append(score)
- tracked_names.append(self.class_names[int(class_id)])
- tracked_masks.append(mask)
-
- return (
- tracked_bboxes,
- track_ids,
- tracked_class_ids,
- tracked_scores,
- tracked_names,
- # tracked_masks,
- )
-
- def visualize_results( # type: ignore[no-untyped-def]
- self, image, bboxes, track_ids, class_ids, confidences, names: Sequence[str]
- ):
- """
- Generate visualization of detection results.
-
- Args:
- image: Original input image
- bboxes: List of bounding boxes
- track_ids: List of tracking IDs
- class_ids: List of class indices
- confidences: List of detection confidences
- names: List of class names
-
- Returns:
- Image with visualized detections
- """
-
- return plot_results(image, bboxes, track_ids, class_ids, confidences, names)
-
- def cleanup(self) -> None:
- """Clean up resources."""
- # Nothing specific to clean up for Detic
- pass
diff --git a/dimos/perception/detection/detectors/test_bbox_detectors.py b/dimos/perception/detection/detectors/test_bbox_detectors.py
index bd9c1358b5..32a509061a 100644
--- a/dimos/perception/detection/detectors/test_bbox_detectors.py
+++ b/dimos/perception/detection/detectors/test_bbox_detectors.py
@@ -12,21 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations
import pytest
+from reactivex.disposable import CompositeDisposable
+from dimos.core import LCMTransport
+from dimos.msgs.sensor_msgs import Image
from dimos.perception.detection.type import Detection2D, ImageDetections2D
-@pytest.fixture(params=["bbox_detector", "person_detector"], scope="session")
+@pytest.fixture(params=["bbox_detector", "person_detector", "yoloe_detector"], scope="session")
def detector(request):
"""Parametrized fixture that provides both bbox and person detectors."""
return request.getfixturevalue(request.param)
@pytest.fixture(scope="session")
-def detections(detector, test_image):
+def get_topic_annotations():
+ disposables = CompositeDisposable()
+
+ def topic_annotations(suffix: str = "unnamed"):
+ annotations: LCMTransport[ImageAnnotations] = LCMTransport(
+ f"/annotations_{suffix}", ImageAnnotations
+ )
+ disposables.add(annotations)
+ return annotations
+
+ yield topic_annotations
+ disposables.dispose()
+
+
+@pytest.fixture(scope="session")
+def detections(detector, test_image, topic_image, get_topic_annotations):
"""Get ImageDetections2D from any detector."""
- return detector.process_image(test_image)
+ topic_image.publish(test_image)
+ detections = detector.process_image(test_image)
+ annotations = detections.to_foxglove_annotations()
+ print("annotations:", annotations)
+ topic_annotations = get_topic_annotations(detector.__class__.__name__)
+ topic_annotations.publish(annotations)
+ return detections
+
+
+@pytest.fixture(scope="session")
+def topic_image():
+ image: LCMTransport[Image] = LCMTransport("/color_image", Image)
+ yield image
+ image.lcm.stop()
def test_detection_basic(detections) -> None:
diff --git a/dimos/perception/detection/detectors/yoloe.py b/dimos/perception/detection/detectors/yoloe.py
new file mode 100644
index 0000000000..9c9881209c
--- /dev/null
+++ b/dimos/perception/detection/detectors/yoloe.py
@@ -0,0 +1,177 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+import threading
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+from ultralytics import YOLOE # type: ignore[attr-defined, import-not-found]
+
+from dimos.msgs.sensor_msgs import Image
+from dimos.perception.detection.detectors.types import Detector
+from dimos.perception.detection.type import ImageDetections2D
+from dimos.utils.data import get_data
+from dimos.utils.gpu_utils import is_cuda_available
+
+
+class YoloePromptMode(Enum):
+ """YOLO-E prompt modes."""
+
+ LRPC = "lrpc"
+ PROMPT = "prompt"
+
+
+class Yoloe2DDetector(Detector):
+ def __init__(
+ self,
+ model_path: str = "models_yoloe",
+ model_name: str | None = None,
+ device: str | None = None,
+ prompt_mode: YoloePromptMode = YoloePromptMode.LRPC,
+ exclude_class_ids: list[int] | None = None,
+ max_area_ratio: float | None = 0.3,
+ ) -> None:
+ """
+ Initialize YOLO-E 2D detector.
+
+ Args:
+ model_path: Path to model directory (fetched via get_data from LFS).
+ model_name: Model filename. Defaults based on prompt_mode.
+ device: Device to run inference on ('cuda', 'cpu', or None for auto).
+ prompt_mode: LRPC for prompt-free detection, PROMPT for text/visual prompting.
+ exclude_class_ids: Class IDs to filter out from results (pass [] to disable).
+ max_area_ratio: Maximum bbox area ratio (0-1) relative to image.
+ """
+ if model_name is None:
+ if prompt_mode == YoloePromptMode.LRPC:
+ model_name = "yoloe-11s-seg-pf.pt"
+ else:
+ model_name = "yoloe-11s-seg.pt"
+
+ self.model = YOLOE(get_data(model_path) / model_name)
+ self.prompt_mode = prompt_mode
+ self._visual_prompts: dict[str, NDArray[Any]] | None = None
+ self.max_area_ratio = max_area_ratio
+ self._lock = threading.Lock()
+
+ if prompt_mode == YoloePromptMode.PROMPT:
+ self.set_prompts(text=["nothing"])
+ self.exclude_class_ids = set(exclude_class_ids) if exclude_class_ids else set()
+
+ if self.max_area_ratio is not None and not (0.0 < self.max_area_ratio <= 1.0):
+ raise ValueError("max_area_ratio must be in the range (0, 1].")
+
+ if device:
+ self.device = device
+ elif is_cuda_available(): # type: ignore[no-untyped-call]
+ self.device = "cuda"
+ else:
+ self.device = "cpu"
+
+ def set_prompts(
+ self,
+ text: list[str] | None = None,
+ bboxes: NDArray[np.float64] | None = None,
+ ) -> None:
+ """
+ Set prompts for detection. Provide either text or bboxes, not both.
+
+ Args:
+ text: List of class names to detect.
+ bboxes: Bounding boxes in xyxy format, shape (N, 4).
+ """
+ if text is not None and bboxes is not None:
+ raise ValueError("Provide either text or bboxes, not both.")
+ if text is None and bboxes is None:
+ raise ValueError("Must provide either text or bboxes.")
+
+ with self._lock:
+ self.model.predictor = None
+ if text is not None:
+ self.model.set_classes(text, self.model.get_text_pe(text)) # type: ignore[no-untyped-call]
+ self._visual_prompts = None
+ else:
+ cls = np.arange(len(bboxes), dtype=np.int16) # type: ignore[arg-type]
+ self._visual_prompts = {"bboxes": bboxes, "cls": cls} # type: ignore[dict-item]
+
+ def process_image(self, image: Image) -> "ImageDetections2D[Any]":
+ """
+ Process an image and return detection results.
+
+ Args:
+ image: Input image
+
+ Returns:
+ ImageDetections2D containing all detected objects
+ """
+ track_kwargs = {
+ "source": image.to_opencv(),
+ "device": self.device,
+ "conf": 0.6,
+ "iou": 0.6,
+ "persist": True,
+ "verbose": False,
+ }
+
+ with self._lock:
+ if self._visual_prompts is not None:
+ track_kwargs["visual_prompts"] = self._visual_prompts
+
+ results = self.model.track(**track_kwargs) # type: ignore[arg-type]
+
+ detections = ImageDetections2D.from_ultralytics_result(image, results)
+ return self._apply_filters(image, detections)
+
+ def _apply_filters(
+ self,
+ image: Image,
+ detections: "ImageDetections2D[Any]",
+ ) -> "ImageDetections2D[Any]":
+ if not self.exclude_class_ids and self.max_area_ratio is None:
+ return detections
+
+ predicates = []
+
+ if self.exclude_class_ids:
+ predicates.append(lambda det: det.class_id not in self.exclude_class_ids)
+
+ if self.max_area_ratio is not None:
+ image_area = image.width * image.height
+
+ def area_filter(det): # type: ignore[no-untyped-def]
+ if image_area <= 0:
+ return True
+ return (det.bbox_2d_volume() / image_area) <= self.max_area_ratio
+
+ predicates.append(area_filter)
+
+ filtered = detections.detections
+ for predicate in predicates:
+ filtered = [det for det in filtered if predicate(det)] # type: ignore[no-untyped-call]
+
+ return ImageDetections2D(image, filtered)
+
+ def stop(self) -> None:
+ """Clean up resources used by the detector."""
+ if hasattr(self.model, "predictor") and self.model.predictor is not None:
+ predictor = self.model.predictor
+ if hasattr(predictor, "trackers") and predictor.trackers:
+ for tracker in predictor.trackers:
+ if hasattr(tracker, "tracker") and hasattr(tracker.tracker, "gmc"):
+ gmc = tracker.tracker.gmc
+ if hasattr(gmc, "executor") and gmc.executor is not None:
+ gmc.executor.shutdown(wait=True)
+ self.model.predictor = None
diff --git a/dimos/perception/detection/objectDB.py b/dimos/perception/detection/objectDB.py
new file mode 100644
index 0000000000..6ade2d8c8d
--- /dev/null
+++ b/dimos/perception/detection/objectDB.py
@@ -0,0 +1,312 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import threading
+import time
+from typing import TYPE_CHECKING, Any
+
+import open3d as o3d # type: ignore[import-untyped]
+
+from dimos.msgs.sensor_msgs import PointCloud2
+from dimos.utils.logging_config import setup_logger
+
+if TYPE_CHECKING:
+ from dimos.msgs.geometry_msgs import Vector3
+ from dimos.perception.detection.type.detection3d.object import Object
+
+logger = setup_logger()
+
+
+class ObjectDB:
+ """Spatial memory database for 3D object detections.
+
+ Maintains two tiers of objects internally:
+ - _pending_objects: Recently detected objects (detections_count < threshold)
+ - _objects: Confirmed permanent objects (detections_count >= threshold)
+
+ Deduplication uses two heuristics:
+ 1. track_id match from YOLOE tracker (recent match)
+ 2. Center distance within threshold (spatial proximity match)
+ """
+
+ def __init__(
+ self,
+ distance_threshold: float = 0.1,
+ min_detections_for_permanent: int = 6,
+ pending_ttl_s: float = 5.0,
+ track_id_ttl_s: float = 2.0,
+ ) -> None:
+ self._distance_threshold = distance_threshold
+ self._min_detections = min_detections_for_permanent
+ self._pending_ttl_s = pending_ttl_s
+ self._track_id_ttl_s = track_id_ttl_s
+
+ # Internal storage - keyed by object_id
+ self._pending_objects: dict[str, Object] = {}
+ self._objects: dict[str, Object] = {} # Permanent objects
+
+ # track_id -> object_id mapping for fast lookup
+ self._track_id_map: dict[int, str] = {}
+ self._last_add_stats: dict[str, int] = {}
+
+ self._lock = threading.RLock()
+
+ # ─────────────────────────────────────────────────────────────────
+ # Public Methods
+ # ─────────────────────────────────────────────────────────────────
+
+ def add_objects(self, objects: list[Object]) -> list[Object]:
+ """Add multiple objects to the database with deduplication.
+
+ Args:
+ objects: List of Object instances from object_scene_registration
+
+ Returns:
+ List of updated/created Object instances
+ """
+ stats = {
+ "input": len(objects),
+ "created": 0,
+ "updated": 0,
+ "promoted": 0,
+ "matched_track": 0,
+ "matched_distance": 0,
+ }
+
+ results: list[Object] = []
+ now = time.time()
+ with self._lock:
+ self._prune_stale_pending(now)
+ for obj in objects:
+ matched, reason = self._match(obj, now)
+ if matched is None:
+ results.append(self._insert_pending(obj, now))
+ stats["created"] += 1
+ continue
+
+ self._update_existing(matched, obj, now)
+ results.append(matched)
+ stats["updated"] += 1
+ if reason == "track":
+ stats["matched_track"] += 1
+ elif reason == "distance":
+ stats["matched_distance"] += 1
+ if self._check_promotion(matched):
+ stats["promoted"] += 1
+
+ stats["pending"] = len(self._pending_objects)
+ stats["permanent"] = len(self._objects)
+ self._last_add_stats = stats
+ return results
+
+ def get_last_add_stats(self) -> dict[str, int]:
+ with self._lock:
+ return dict(self._last_add_stats)
+
+ def get_objects(self) -> list[Object]:
+ """Get all permanent objects (detections_count >= threshold)."""
+ with self._lock:
+ return list(self._objects.values())
+
+ def get_all_objects(self) -> list[Object]:
+ """Get all objects (both pending and permanent)."""
+ with self._lock:
+ return list(self._pending_objects.values()) + list(self._objects.values())
+
+ def promote(self, object_id: str) -> bool:
+ """Promote an object from pending to permanent."""
+ with self._lock:
+ if object_id in self._pending_objects:
+ self._objects[object_id] = self._pending_objects.pop(object_id)
+ return True
+ return object_id in self._objects
+
+ def find_by_name(self, name: str) -> list[Object]:
+ """Find all permanent objects with matching name."""
+ with self._lock:
+ return [obj for obj in self._objects.values() if obj.name == name]
+
+ def find_nearest(
+ self,
+ position: Vector3,
+ name: str | None = None,
+ ) -> Object | None:
+ """Find nearest permanent object to a position, optionally filtered by name.
+
+ Args:
+ position: Position to search from
+ name: Optional name filter
+
+ Returns:
+ Nearest Object or None if no objects found
+ """
+ with self._lock:
+ candidates = [
+ obj
+ for obj in self._objects.values()
+ if obj.center is not None and (name is None or obj.name == name)
+ ]
+
+ if not candidates:
+ return None
+
+ return min(candidates, key=lambda obj: position.distance(obj.center)) # type: ignore[arg-type]
+
+ def clear(self) -> None:
+ """Clear all objects from the database."""
+ with self._lock:
+ # Drop Open3D pointcloud references before clearing to reduce shutdown warnings.
+ for obj in list(self._pending_objects.values()) + list(self._objects.values()):
+ obj.pointcloud = PointCloud2(
+ pointcloud=o3d.geometry.PointCloud(),
+ frame_id=obj.pointcloud.frame_id,
+ ts=obj.pointcloud.ts,
+ )
+ self._pending_objects.clear()
+ self._objects.clear()
+ self._track_id_map.clear()
+ logger.info("ObjectDB cleared")
+
+ def get_stats(self) -> dict[str, int]:
+ """Get statistics about the database."""
+ with self._lock:
+ return {
+ "pending_count": len(self._pending_objects),
+ "permanent_count": len(self._objects),
+ "total_count": len(self._pending_objects) + len(self._objects),
+ }
+
+ # ─────────────────────────────────────────────────────────────────
+ # Internal Methods
+ # ─────────────────────────────────────────────────────────────────
+
+ def _match(self, obj: Object, now: float) -> tuple[Object | None, str | None]:
+ if obj.track_id >= 0:
+ matched = self._match_by_track_id(obj.track_id, now)
+ if matched is not None:
+ return matched, "track"
+
+ matched = self._match_by_distance(obj)
+ if matched is not None:
+ return matched, "distance"
+ return None, None
+
+ def _insert_pending(self, obj: Object, now: float) -> Object:
+ if not obj.ts:
+ obj.ts = now
+ self._pending_objects[obj.object_id] = obj
+ if obj.track_id >= 0:
+ self._track_id_map[obj.track_id] = obj.object_id
+ logger.info(f"Created new pending object {obj.object_id} ({obj.name})")
+ return obj
+
+ def _update_existing(self, existing: Object, obj: Object, now: float) -> None:
+ existing.update_object(obj)
+ existing.ts = obj.ts or now
+ if obj.track_id >= 0:
+ self._track_id_map[obj.track_id] = existing.object_id
+
+ def _match_by_track_id(self, track_id: int, now: float) -> Object | None:
+ """Find object with matching track_id from YOLOE."""
+ if track_id < 0:
+ return None
+
+ object_id = self._track_id_map.get(track_id)
+ if object_id is None:
+ return None
+
+ # Check in permanent objects first
+ if object_id in self._objects:
+ obj = self._objects[object_id]
+ elif object_id in self._pending_objects:
+ obj = self._pending_objects[object_id]
+ else:
+ del self._track_id_map[track_id]
+ return None
+
+ last_seen = obj.ts if obj.ts else now
+ if now - last_seen > self._track_id_ttl_s:
+ del self._track_id_map[track_id]
+ return None
+
+ return obj
+
+ def _match_by_distance(self, obj: Object) -> Object | None:
+ """Find object within distance threshold."""
+ if obj.center is None:
+ return None
+
+ # Combine all objects and filter by valid center
+ all_objects = list(self._objects.values()) + list(self._pending_objects.values())
+ candidates = [
+ o
+ for o in all_objects
+ if o.center is not None and obj.center.distance(o.center) < self._distance_threshold
+ ]
+
+ if not candidates:
+ return None
+
+ return min(candidates, key=lambda o: obj.center.distance(o.center)) # type: ignore[union-attr]
+
+ def _prune_stale_pending(self, now: float) -> None:
+ if self._pending_ttl_s <= 0:
+ return
+ cutoff = now - self._pending_ttl_s
+ stale_ids = [
+ obj_id for obj_id, obj in self._pending_objects.items() if (obj.ts or now) < cutoff
+ ]
+ for obj_id in stale_ids:
+ del self._pending_objects[obj_id]
+ for track_id, mapped_id in list(self._track_id_map.items()):
+ if mapped_id == obj_id:
+ del self._track_id_map[track_id]
+
+ def _check_promotion(self, obj: Object) -> bool:
+ """Move object from pending to permanent if threshold met."""
+ if obj.detections_count >= self._min_detections:
+ # Check if it's in pending
+ if obj.object_id in self._pending_objects:
+ # Promote to permanent
+ del self._pending_objects[obj.object_id]
+ self._objects[obj.object_id] = obj
+ logger.info(
+ f"Promoted object {obj.object_id} ({obj.name}) to permanent "
+ f"with {obj.detections_count} detections"
+ )
+ return True
+ return False
+
+ # ─────────────────────────────────────────────────────────────────
+ # Agent encoding
+ # ─────────────────────────────────────────────────────────────────
+
+ def agent_encode(self) -> list[dict[str, Any]]:
+ """Encode permanent objects for agent consumption."""
+ with self._lock:
+ return [obj.agent_encode() for obj in self._objects.values()]
+
+ def __len__(self) -> int:
+ """Return number of permanent objects."""
+ with self._lock:
+ return len(self._objects)
+
+ def __repr__(self) -> str:
+ with self._lock:
+ return f"ObjectDB(permanent={len(self._objects)}, pending={len(self._pending_objects)})"
+
+
+__all__ = ["ObjectDB"]
diff --git a/dimos/perception/detection/person_tracker.py b/dimos/perception/detection/person_tracker.py
index 6212080858..50082742f0 100644
--- a/dimos/perception/detection/person_tracker.py
+++ b/dimos/perception/detection/person_tracker.py
@@ -84,9 +84,7 @@ def detections_stream(self) -> Observable[ImageDetections2D]:
buffer_size=2.0,
).pipe(
ops.map(
- lambda pair: ImageDetections2D.from_ros_detection2d_array( # type: ignore[misc]
- *pair
- )
+ lambda pair: ImageDetections2D.from_ros_detection2d_array(*pair) # type: ignore[misc, arg-type]
)
)
)
diff --git a/dimos/perception/detection/reid/module.py b/dimos/perception/detection/reid/module.py
index 4e239da39a..f3f2a5a126 100644
--- a/dimos/perception/detection/reid/module.py
+++ b/dimos/perception/detection/reid/module.py
@@ -65,7 +65,7 @@ def detections_stream(self) -> Observable[ImageDetections2D]:
),
match_tolerance=0.0,
buffer_size=2.0,
- ).pipe(ops.map(lambda pair: ImageDetections2D.from_ros_detection2d_array(*pair))) # type: ignore[misc]
+ ).pipe(ops.map(lambda pair: ImageDetections2D.from_ros_detection2d_array(*pair))) # type: ignore[misc, arg-type]
)
@rpc
diff --git a/dimos/perception/detection/reid/test_embedding_id_system.py b/dimos/perception/detection/reid/test_embedding_id_system.py
index 3a0899c848..b9e6f591ee 100644
--- a/dimos/perception/detection/reid/test_embedding_id_system.py
+++ b/dimos/perception/detection/reid/test_embedding_id_system.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import numpy as np
import pytest
-import torch
from dimos.msgs.sensor_msgs import Image
from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem
@@ -25,8 +25,7 @@ def mobileclip_model():
"""Load MobileCLIP model once for all tests."""
from dimos.models.embedding.mobileclip import MobileCLIPModel
- model_path = get_data("models_mobileclip") / "mobileclip2_s0.pt"
- model = MobileCLIPModel(model_name="MobileCLIP2-S0", model_path=model_path)
+ model = MobileCLIPModel() # Uses default MobileCLIP2-S4
model.start()
return model
@@ -34,7 +33,11 @@ def mobileclip_model():
@pytest.fixture
def track_associator(mobileclip_model):
"""Create fresh EmbeddingIDSystem for each test."""
- return EmbeddingIDSystem(model=lambda: mobileclip_model, similarity_threshold=0.75)
+ return EmbeddingIDSystem(
+ model=lambda: mobileclip_model,
+ similarity_threshold=0.75,
+ min_embeddings_for_matching=1, # Allow matching with single embedding for tests
+ )
@pytest.fixture(scope="session")
@@ -52,39 +55,40 @@ def test_update_embedding_single(track_associator, mobileclip_model, test_image)
track_associator.update_embedding(track_id=1, new_embedding=embedding)
assert 1 in track_associator.track_embeddings
- assert track_associator.embedding_counts[1] == 1
+ assert len(track_associator.track_embeddings[1]) == 1
- # Verify embedding is on device and normalized
- emb_vec = track_associator.track_embeddings[1]
- assert isinstance(emb_vec, torch.Tensor)
- assert emb_vec.device.type in ["cuda", "cpu"]
- norm = torch.norm(emb_vec).item()
+ # Verify embedding is stored as numpy array and normalized
+ emb_vec = track_associator.track_embeddings[1][0]
+ assert isinstance(emb_vec, np.ndarray)
+ norm = np.linalg.norm(emb_vec)
assert abs(norm - 1.0) < 0.01, "Embedding should be normalized"
@pytest.mark.gpu
-def test_update_embedding_running_average(track_associator, mobileclip_model, test_image) -> None:
- """Test running average of embeddings."""
+def test_update_embedding_multiple(track_associator, mobileclip_model, test_image) -> None:
+ """Test storing multiple embeddings per track."""
embedding1 = mobileclip_model.embed(test_image)
embedding2 = mobileclip_model.embed(test_image)
# Add first embedding
track_associator.update_embedding(track_id=1, new_embedding=embedding1)
- first_vec = track_associator.track_embeddings[1].clone()
+ first_vec = track_associator.track_embeddings[1][0].copy()
# Add second embedding (same image, should be very similar)
track_associator.update_embedding(track_id=1, new_embedding=embedding2)
- avg_vec = track_associator.track_embeddings[1]
- assert track_associator.embedding_counts[1] == 2
+ # Should have 2 embeddings now
+ assert len(track_associator.track_embeddings[1]) == 2
- # Average should still be normalized
- norm = torch.norm(avg_vec).item()
- assert abs(norm - 1.0) < 0.01, "Average embedding should be normalized"
+ # Both should be normalized
+ for emb in track_associator.track_embeddings[1]:
+ norm = np.linalg.norm(emb)
+ assert abs(norm - 1.0) < 0.01, "Embedding should be normalized"
- # Average should be similar to both originals (same image)
- similarity1 = (first_vec @ avg_vec).item()
- assert similarity1 > 0.99, "Average should be very similar to original"
+ # Second embedding should be similar to first (same image)
+ second_vec = track_associator.track_embeddings[1][1]
+ similarity = float(np.dot(first_vec, second_vec))
+ assert similarity > 0.99, "Same image should produce very similar embeddings"
@pytest.mark.gpu
@@ -199,31 +203,33 @@ def test_associate_returns_cached(track_associator, mobileclip_model, test_image
@pytest.mark.gpu
-def test_associate_not_ready(track_associator) -> None:
- """Test that associate returns -1 for track without embedding."""
+def test_associate_no_embedding(track_associator) -> None:
+ """Test that associate creates new ID for track without embedding."""
+ # Track with no embedding gets assigned a new ID
long_term_id = track_associator.associate(track_id=999)
- assert long_term_id == -1, "Should return -1 for track without embedding"
+ assert long_term_id == 0, "Track without embedding should get new long_term_id"
+ assert track_associator.long_term_counter == 1
@pytest.mark.gpu
-def test_gpu_performance(track_associator, mobileclip_model, test_image) -> None:
- """Test that embeddings stay on GPU for performance."""
+def test_embeddings_stored_as_numpy(track_associator, mobileclip_model, test_image) -> None:
+ """Test that embeddings are stored as numpy arrays for efficient CPU comparisons."""
embedding = mobileclip_model.embed(test_image)
track_associator.update_embedding(track_id=1, new_embedding=embedding)
- # Embedding should stay on device
- emb_vec = track_associator.track_embeddings[1]
- assert isinstance(emb_vec, torch.Tensor)
- # Device comparison (handle "cuda" vs "cuda:0")
- expected_device = mobileclip_model.device
- assert emb_vec.device.type == torch.device(expected_device).type
+ # Embeddings should be stored as numpy arrays
+ emb_list = track_associator.track_embeddings[1]
+ assert isinstance(emb_list, list)
+ assert len(emb_list) == 1
+ assert isinstance(emb_list[0], np.ndarray)
- # Running average should happen on GPU
+ # Add more embeddings
embedding2 = mobileclip_model.embed(test_image)
track_associator.update_embedding(track_id=1, new_embedding=embedding2)
- avg_vec = track_associator.track_embeddings[1]
- assert avg_vec.device.type == torch.device(expected_device).type
+ assert len(track_associator.track_embeddings[1]) == 2
+ for emb in track_associator.track_embeddings[1]:
+ assert isinstance(emb, np.ndarray)
@pytest.mark.gpu
@@ -247,12 +253,12 @@ def test_multi_track_scenario(track_associator, mobileclip_model, test_image) ->
# Frame 2: Track 1 and Track 2 appear (different objects)
text_emb = mobileclip_model.embed_text("a dog")
- track_associator.update_embedding(1, emb1) # Update average
+ track_associator.update_embedding(1, emb1) # Update embedding
track_associator.update_embedding(2, text_emb)
track_associator.add_negative_constraints([1, 2]) # Co-occur = different
lt2 = track_associator.associate(2)
- # Track 2 should get different ID despite any similarity
+ # Track 2 should get different ID due to negative constraint
assert lt1 != lt2
# Frame 3: Track 1 disappears, Track 3 appears (same as Track 1)
diff --git a/dimos/perception/detection/type/detection2d/__init__.py b/dimos/perception/detection/type/detection2d/__init__.py
index ad3b7fa62e..8994d840b6 100644
--- a/dimos/perception/detection/type/detection2d/__init__.py
+++ b/dimos/perception/detection/type/detection2d/__init__.py
@@ -17,11 +17,13 @@
from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D
from dimos.perception.detection.type.detection2d.person import Detection2DPerson
from dimos.perception.detection.type.detection2d.point import Detection2DPoint
+from dimos.perception.detection.type.detection2d.seg import Detection2DSeg
__all__ = [
"Detection2D",
"Detection2DBBox",
"Detection2DPerson",
"Detection2DPoint",
+ "Detection2DSeg",
"ImageDetections2D",
]
diff --git a/dimos/perception/detection/type/detection2d/imageDetections2D.py b/dimos/perception/detection/type/detection2d/imageDetections2D.py
index 680f9dd117..34033a9c50 100644
--- a/dimos/perception/detection/type/detection2d/imageDetections2D.py
+++ b/dimos/perception/detection/type/detection2d/imageDetections2D.py
@@ -14,28 +14,32 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Generic
+from typing import TYPE_CHECKING, Any, Generic
from typing_extensions import TypeVar
from dimos.perception.detection.type.detection2d.base import Detection2D
from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox
+from dimos.perception.detection.type.detection2d.person import Detection2DPerson
+from dimos.perception.detection.type.detection2d.seg import Detection2DSeg
from dimos.perception.detection.type.imageDetections import ImageDetections
if TYPE_CHECKING:
- from dimos_lcm.vision_msgs import Detection2DArray
- from ultralytics.engine.results import Results # type: ignore[import-not-found]
+ from ultralytics.engine.results import Results
from dimos.msgs.sensor_msgs import Image
+ from dimos.msgs.vision_msgs import Detection2DArray
-# TypeVar with default - Detection2DBBox is the default when no type param given
T2D = TypeVar("T2D", bound=Detection2D, default=Detection2DBBox)
class ImageDetections2D(ImageDetections[T2D], Generic[T2D]):
@classmethod
- def from_ros_detection2d_array( # type: ignore[no-untyped-def]
- cls, image: Image, ros_detections: Detection2DArray, **kwargs
+ def from_ros_detection2d_array(
+ cls,
+ image: Image,
+ ros_detections: Detection2DArray,
+ **kwargs: Any,
) -> ImageDetections2D[Detection2DBBox]:
"""Convert from ROS Detection2DArray message to ImageDetections2D object."""
detections: list[Detection2DBBox] = []
@@ -47,24 +51,25 @@ def from_ros_detection2d_array( # type: ignore[no-untyped-def]
return ImageDetections2D(image=image, detections=detections)
@classmethod
- def from_ultralytics_result( # type: ignore[no-untyped-def]
- cls, image: Image, results: list[Results], **kwargs
+ def from_ultralytics_result(
+ cls,
+ image: Image,
+ results: list[Results],
) -> ImageDetections2D[Detection2DBBox]:
"""Create ImageDetections2D from ultralytics Results.
Dispatches to appropriate Detection2D subclass based on result type:
+ - If masks present: creates Detection2DSeg
- If keypoints present: creates Detection2DPerson
- Otherwise: creates Detection2DBBox
Args:
image: Source image
results: List of ultralytics Results objects
- **kwargs: Additional arguments passed to detection constructors
Returns:
ImageDetections2D containing appropriate detection types
"""
- from dimos.perception.detection.type.detection2d.person import Detection2DPerson
detections: list[Detection2DBBox] = []
for result in results:
@@ -74,7 +79,10 @@ def from_ultralytics_result( # type: ignore[no-untyped-def]
num_detections = len(result.boxes.xyxy)
for i in range(num_detections):
detection: Detection2DBBox
- if result.keypoints is not None:
+ if result.masks is not None:
+ # Segmentation detection with mask
+ detection = Detection2DSeg.from_ultralytics_result(result, i, image)
+ elif result.keypoints is not None:
# Pose detection with keypoints
detection = Detection2DPerson.from_ultralytics_result(result, i, image)
else:
diff --git a/dimos/perception/detection/type/detection2d/seg.py b/dimos/perception/detection/type/detection2d/seg.py
new file mode 100644
index 0000000000..21f8e8e689
--- /dev/null
+++ b/dimos/perception/detection/type/detection2d/seg.py
@@ -0,0 +1,204 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import cv2
+from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation
+from dimos_lcm.foxglove_msgs.Point2 import Point2
+import numpy as np
+import torch
+
+from dimos.msgs.foxglove_msgs.Color import Color
+from dimos.perception.detection.type.detection2d.bbox import Bbox, Detection2DBBox
+from dimos.types.timestamped import to_ros_stamp
+
+if TYPE_CHECKING:
+ from ultralytics.engine.results import Results
+
+ from dimos.msgs.sensor_msgs import Image
+
+
+@dataclass
+class Detection2DSeg(Detection2DBBox):
+ """Represents a detection with a segmentation mask."""
+
+ mask: np.ndarray[Any, np.dtype[np.uint8]] # Binary mask [H, W], uint8 0 or 255
+
+ @classmethod
+ def from_sam2_result(
+ cls,
+ mask: np.ndarray[Any, Any] | torch.Tensor,
+ obj_id: int,
+ image: Image,
+ class_id: int = 0,
+ name: str = "object",
+ confidence: float = 1.0,
+ ) -> Detection2DSeg:
+ """Create Detection2DSeg from SAM2 output (single object).
+
+ Args:
+ mask: Segmentation mask (logits or binary). Shape [H, W] or [1, H, W].
+ obj_id: Tracking ID of the object.
+ image: Source image.
+ class_id: Class ID (default 0).
+ name: Class name (default "object").
+ confidence: Confidence score (default 1.0).
+
+ Returns:
+ Detection2DSeg instance.
+ """
+ # Convert mask to numpy if tensor
+ if isinstance(mask, torch.Tensor):
+ mask = mask.detach().cpu().numpy()
+
+ # Handle dimensions (EdgeTAM might return [1, H, W] or [H, W])
+ if mask.ndim == 3:
+ mask = mask.squeeze()
+
+ # Binarize if it's logits (usually < 0 is background, > 0 is foreground)
+ # or if it's boolean
+ if mask.dtype == bool:
+ mask = mask.astype(np.uint8) * 255
+ elif np.issubdtype(mask.dtype, np.floating):
+ mask = (mask > 0.0).astype(np.uint8) * 255
+
+ # Calculate bbox
+ y_indices, x_indices = np.where(mask > 0)
+ if len(x_indices) > 0:
+ x1_val, y1_val = float(np.min(x_indices)), float(np.min(y_indices))
+ x2_val, y2_val = float(np.max(x_indices)), float(np.max(y_indices))
+ else:
+ x1_val = y1_val = x2_val = y2_val = 0.0
+
+ bbox = (x1_val, y1_val, x2_val, y2_val)
+
+ return cls(
+ bbox=bbox,
+ track_id=obj_id,
+ class_id=class_id,
+ confidence=confidence,
+ name=name,
+ ts=image.ts,
+ image=image,
+ mask=mask.astype(np.uint8), # type: ignore[arg-type]
+ )
+
+ @classmethod
+ def from_ultralytics_result(cls, result: Results, idx: int, image: Image) -> Detection2DSeg:
+ """Create Detection2DSeg from ultralytics Results object with segmentation mask.
+
+ Args:
+ result: Ultralytics Results object containing detection and mask data
+ idx: Index of the detection in the results
+ image: Source image
+
+ Returns:
+ Detection2DSeg instance
+ """
+ if result.boxes is None:
+ raise ValueError("Result has no boxes")
+
+ # Extract bounding box coordinates
+ bbox_array = result.boxes.xyxy[idx].cpu().numpy()
+ bbox: Bbox = (
+ float(bbox_array[0]),
+ float(bbox_array[1]),
+ float(bbox_array[2]),
+ float(bbox_array[3]),
+ )
+
+ # Extract confidence
+ confidence = float(result.boxes.conf[idx].cpu())
+
+ # Extract class ID and name
+ class_id = int(result.boxes.cls[idx].cpu())
+ if hasattr(result, "names") and result.names is not None:
+ if isinstance(result.names, dict):
+ name = result.names.get(class_id, f"class_{class_id}")
+ elif isinstance(result.names, list) and class_id < len(result.names):
+ name = result.names[class_id]
+ else:
+ name = f"class_{class_id}"
+ else:
+ name = f"class_{class_id}"
+
+ # Extract track ID if available
+ track_id = -1
+ if hasattr(result.boxes, "id") and result.boxes.id is not None:
+ track_id = int(result.boxes.id[idx].cpu())
+
+ # Extract mask
+ mask = np.zeros((image.height, image.width), dtype=np.uint8)
+ if result.masks is not None and idx < len(result.masks.data):
+ mask_tensor = result.masks.data[idx]
+ mask_np = mask_tensor.cpu().numpy()
+
+ # Resize mask to image size if needed
+ if mask_np.shape != (image.height, image.width):
+ mask_np = cv2.resize(
+ mask_np.astype(np.float32),
+ (image.width, image.height),
+ interpolation=cv2.INTER_LINEAR,
+ )
+
+ # Binarize mask
+ mask = (mask_np > 0.5).astype(np.uint8) * 255 # type: ignore[assignment]
+
+ return cls(
+ bbox=bbox,
+ track_id=track_id,
+ class_id=class_id,
+ confidence=confidence,
+ name=name,
+ ts=image.ts,
+ image=image,
+ mask=mask,
+ )
+
+ def to_points_annotation(self) -> list[PointsAnnotation]:
+ """Override to include mask outline."""
+ annotations = super().to_points_annotation()
+
+ # Find contours
+ contours, _ = cv2.findContours(self.mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ for contour in contours:
+ # Simplify contour to reduce points
+ epsilon = 0.005 * cv2.arcLength(contour, True)
+ approx = cv2.approxPolyDP(contour, epsilon, True)
+
+ points = []
+ for pt in approx:
+ points.append(Point2(x=float(pt[0][0]), y=float(pt[0][1])))
+
+ if len(points) < 3:
+ continue
+
+ annotations.append(
+ PointsAnnotation(
+ timestamp=to_ros_stamp(self.ts),
+ outline_color=Color.from_string(str(self.class_id), alpha=1.0, brightness=1.25),
+ fill_color=Color.from_string(str(self.track_id), alpha=0.4),
+ thickness=1.0,
+ points_length=len(points),
+ points=points,
+ type=PointsAnnotation.LINE_LOOP,
+ )
+ )
+
+ return annotations
diff --git a/dimos/perception/detection/type/detection3d/base.py b/dimos/perception/detection/type/detection3d/base.py
index d8cc430c44..b036584f3e 100644
--- a/dimos/perception/detection/type/detection3d/base.py
+++ b/dimos/perception/detection/type/detection3d/base.py
@@ -15,23 +15,22 @@
from __future__ import annotations
from abc import abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import TYPE_CHECKING
+from dimos.msgs.geometry_msgs import Transform
from dimos.perception.detection.type.detection2d import Detection2DBBox
if TYPE_CHECKING:
from dimos_lcm.sensor_msgs import CameraInfo
- from dimos.msgs.geometry_msgs import Transform
-
@dataclass
class Detection3D(Detection2DBBox):
"""Abstract base class for 3D detections."""
- transform: Transform
- frame_id: str
+ frame_id: str = ""
+ transform: Transform = field(default_factory=Transform.identity)
@classmethod
@abstractmethod
diff --git a/dimos/perception/detection/type/detection3d/bbox.py b/dimos/perception/detection/type/detection3d/bbox.py
index ac6f82a25e..cf7f4ea3cc 100644
--- a/dimos/perception/detection/type/detection3d/bbox.py
+++ b/dimos/perception/detection/type/detection3d/bbox.py
@@ -14,11 +14,15 @@
from __future__ import annotations
-from dataclasses import dataclass
+from dataclasses import dataclass, field
import functools
from typing import Any
-from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3
+from dimos_lcm.vision_msgs import ObjectHypothesis, ObjectHypothesisWithPose
+
+from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3
+from dimos.msgs.std_msgs import Header
+from dimos.msgs.vision_msgs import Detection3D
from dimos.perception.detection.type.detection2d import Detection2DBBox
@@ -29,11 +33,11 @@ class Detection3DBBox(Detection2DBBox):
Represents a 3D detection as an oriented bounding box in world space.
"""
- transform: Transform # Camera to world transform
- frame_id: str # Frame ID (e.g., "world", "map")
center: Vector3 # Center point in world frame
size: Vector3 # Width, height, depth
- orientation: tuple[float, float, float, float] # Quaternion (x, y, z, w)
+ transform: Transform | None = None # Camera to world transform
+ frame_id: str = "" # Frame ID (e.g., "world", "map")
+ orientation: Quaternion = field(default_factory=lambda: Quaternion(0.0, 0.0, 0.0, 1.0))
@functools.cached_property
def pose(self) -> PoseStamped:
@@ -48,8 +52,34 @@ def pose(self) -> PoseStamped:
orientation=self.orientation,
)
+ def to_detection3d_msg(self) -> Detection3D:
+ """Convert to ROS Detection3D message."""
+ msg = Detection3D()
+ msg.header = Header(self.ts, self.frame_id)
+
+ # Results
+ msg.results = [
+ ObjectHypothesisWithPose(
+ hypothesis=ObjectHypothesis(
+ class_id=str(self.class_id),
+ score=self.confidence,
+ )
+ )
+ ]
+
+ # Bounding Box
+ msg.bbox.center = Pose(
+ position=self.center,
+ orientation=self.orientation,
+ )
+ msg.bbox.size = self.size
+
+ return msg
+
def to_repr_dict(self) -> dict[str, Any]:
# Calculate distance from camera
+ if self.transform is None:
+ return super().to_repr_dict()
camera_pos = self.transform.translation
distance = (self.center - camera_pos).magnitude()
diff --git a/dimos/perception/detection/type/detection3d/object.py b/dimos/perception/detection/type/detection3d/object.py
new file mode 100644
index 0000000000..00d4d88661
--- /dev/null
+++ b/dimos/perception/detection/type/detection3d/object.py
@@ -0,0 +1,363 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+import time
+from typing import TYPE_CHECKING, Any
+import uuid
+
+import cv2
+from dimos_lcm.geometry_msgs import Pose
+import numpy as np
+import open3d as o3d # type: ignore[import-untyped]
+
+from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3
+from dimos.msgs.sensor_msgs import Image, PointCloud2
+from dimos.msgs.std_msgs import Header
+from dimos.msgs.vision_msgs import Detection3D as ROSDetection3D, Detection3DArray
+from dimos.perception.detection.type.detection2d.seg import Detection2DSeg
+from dimos.perception.detection.type.detection3d.base import Detection3D
+
+if TYPE_CHECKING:
+ from dimos_lcm.sensor_msgs import CameraInfo
+
+ from dimos.perception.detection.type.detection2d import ImageDetections2D
+
+
+@dataclass(kw_only=True)
+class Object(Detection3D):
+ """3D object detection combining bounding box and pointcloud representations.
+
+ Represents a detected object in 3D space with support for accumulating
+ multiple detections over time.
+ """
+
+ object_id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
+ center: Vector3
+ size: Vector3
+ pose: PoseStamped
+ pointcloud: PointCloud2
+ camera_transform: Transform | None = None
+ mask: np.ndarray[Any, np.dtype[np.uint8]] | None = None
+ detections_count: int = 1
+
+ def update_object(self, other: Object) -> None:
+ """Update this object with data from another detection.
+
+ Accumulates pointclouds by transforming the new pointcloud to world frame
+ and adding it to the existing pointcloud. Updates center and camera_transform,
+ and increments the detections_count.
+
+ Args:
+ other: Another Object instance with newer detection data.
+ """
+ # Accumulate pointclouds if transforms are available
+ if other.camera_transform is not None:
+ # Transform new pointcloud to world frame and add to existing
+ # transformed_pc = other.pointcloud.transform(other.camera_transform)
+ # self.pointcloud = self.pointcloud + transformed_pc
+
+ # Recompute center from accumulated pointcloud
+ self.pointcloud = other.pointcloud
+ pc_center = other.pointcloud.center
+ self.center = Vector3(pc_center.x, pc_center.y, pc_center.z)
+ else:
+ # No transform available, just replace
+ self.pointcloud = other.pointcloud
+ self.center = other.center
+
+ self.camera_transform = other.camera_transform
+ self.size = other.size
+ self.pose = other.pose
+ self.track_id = other.track_id
+ self.mask = other.mask
+ self.name = other.name
+ self.bbox = other.bbox
+ self.confidence = other.confidence
+ self.class_id = other.class_id
+ self.ts = other.ts
+ self.frame_id = other.frame_id
+ self.image = other.image
+ self.detections_count += 1
+
+ def get_oriented_bounding_box(self) -> Any:
+ """Get oriented bounding box of the pointcloud."""
+ return self.pointcloud.get_oriented_bounding_box()
+
+ def scene_entity_label(self) -> str:
+ """Get label for scene visualization."""
+ if self.detections_count > 1:
+ return f"{self.name} ({self.detections_count})"
+ return f"{self.track_id}/{self.name} ({self.confidence:.0%})"
+
+ def to_detection3d_msg(self) -> ROSDetection3D:
+ """Convert to ROS Detection3D message."""
+ obb = self.get_oriented_bounding_box() # type: ignore[no-untyped-call]
+ orientation = Quaternion.from_rotation_matrix(obb.R)
+
+ msg = ROSDetection3D()
+ msg.header = Header(self.ts, self.frame_id)
+ msg.id = str(self.track_id)
+ msg.bbox.center = Pose(
+ position=Vector3(obb.center[0], obb.center[1], obb.center[2]),
+ orientation=orientation,
+ )
+ msg.bbox.size = Vector3(obb.extent[0], obb.extent[1], obb.extent[2])
+
+ return msg
+
+ def agent_encode(self) -> dict[str, Any]:
+ """Encode for agent consumption."""
+ return {
+ "id": self.track_id,
+ "name": self.name,
+ "detections": self.detections_count,
+ "last_seen": f"{round(time.time() - self.ts)}s ago",
+ }
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert object to dictionary with all relevant data."""
+ return {
+ "object_id": self.object_id,
+ "track_id": self.track_id,
+ "class_id": self.class_id,
+ "name": self.name,
+ "mask": self.mask,
+ "pointcloud": self.pointcloud.as_numpy(),
+ "image": self.image.as_numpy() if self.image else None,
+ }
+
+ @classmethod
+ def from_2d_to_list(
+ cls,
+ detections_2d: ImageDetections2D[Detection2DSeg],
+ color_image: Image,
+ depth_image: Image,
+ camera_info: CameraInfo,
+ camera_transform: Transform | None = None,
+ depth_scale: float = 1.0,
+ depth_trunc: float = 10.0,
+ statistical_nb_neighbors: int = 10,
+ statistical_std_ratio: float = 0.5,
+ voxel_downsample: float = 0.005,
+ mask_erode_pixels: int = 3,
+ ) -> list[Object]:
+ """Create 3D Objects from 2D detections and RGBD images.
+
+ Uses Open3D's optimized RGBD projection for efficient processing.
+
+ Args:
+ detections_2d: 2D detections with segmentation masks
+ color_image: RGB color image
+ depth_image: Depth image (in meters if depth_scale=1.0)
+ camera_info: Camera intrinsics
+ camera_transform: Optional transform from camera frame to world frame.
+ If provided, pointclouds will be transformed to world frame.
+ depth_scale: Scale factor for depth (1.0 for meters, 1000.0 for mm)
+ depth_trunc: Maximum depth value in meters
+ statistical_nb_neighbors: Neighbors for statistical outlier removal
+ statistical_std_ratio: Std ratio for statistical outlier removal
+ voxel_downsample: Voxel size (meters) for downsampling before filtering. Set <= 0 to skip.
+ mask_erode_pixels: Number of pixels to erode the mask by to remove
+ noisy depth edge points. Set to 0 to disable.
+
+ Returns:
+ List of Object instances with pointclouds
+ """
+ color_cv = color_image.to_opencv()
+ if color_cv.ndim == 3 and color_cv.shape[2] == 3:
+ color_cv = cv2.cvtColor(color_cv, cv2.COLOR_BGR2RGB)
+
+ depth_cv = depth_image.to_opencv()
+ h, w = depth_cv.shape[:2]
+
+ # Build Open3D camera intrinsics
+ fx, fy = camera_info.K[0], camera_info.K[4]
+ cx, cy = camera_info.K[2], camera_info.K[5]
+ intrinsic_o3d = o3d.camera.PinholeCameraIntrinsic(w, h, fx, fy, cx, cy)
+
+ objects: list[Object] = []
+
+ for det in detections_2d.detections:
+ if isinstance(det, Detection2DSeg):
+ mask = det.mask
+ store_mask = det.mask
+ else:
+ mask = np.zeros((h, w), dtype=np.uint8)
+ x1, y1, x2, y2 = map(int, det.bbox)
+ x1, y1 = max(0, x1), max(0, y1)
+ x2, y2 = min(w, x2), min(h, y2)
+ mask[y1:y2, x1:x2] = 255
+ store_mask = mask
+
+ if mask_erode_pixels > 0:
+ mask_uint8 = mask.astype(np.uint8)
+ if mask_uint8.max() == 1:
+ mask_uint8 = mask_uint8 * 255 # type: ignore[assignment]
+ kernel_size = 2 * mask_erode_pixels + 1
+ erode_kernel = cv2.getStructuringElement(
+ cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
+ )
+ mask = cv2.erode(mask_uint8, erode_kernel) # type: ignore[assignment]
+
+ depth_masked = depth_cv.copy()
+ depth_masked[mask == 0] = 0
+
+ rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
+ o3d.geometry.Image(color_cv.astype(np.uint8)),
+ o3d.geometry.Image(depth_masked.astype(np.float32)),
+ depth_scale=depth_scale,
+ depth_trunc=depth_trunc,
+ convert_rgb_to_intensity=False,
+ )
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic_o3d)
+
+ pc0 = PointCloud2(
+ pcd,
+ frame_id=depth_image.frame_id,
+ ts=depth_image.ts,
+ ).voxel_downsample(voxel_downsample)
+
+ pcd_filtered, _ = pc0.pointcloud.remove_statistical_outlier(
+ nb_neighbors=statistical_nb_neighbors,
+ std_ratio=statistical_std_ratio,
+ )
+
+ if len(pcd_filtered.points) < 10:
+ continue
+
+ pc = PointCloud2(
+ pcd_filtered,
+ frame_id=depth_image.frame_id,
+ ts=depth_image.ts,
+ )
+
+ # Transform pointcloud to world frame if camera_transform is provided
+ if camera_transform is not None:
+ pc = pc.transform(camera_transform)
+ frame_id = camera_transform.frame_id
+ else:
+ frame_id = depth_image.frame_id
+
+ # Compute center from pointcloud
+ obb = pc.pointcloud.get_oriented_bounding_box()
+ center = Vector3(obb.center[0], obb.center[1], obb.center[2])
+ size = Vector3(obb.extent[0], obb.extent[1], obb.extent[2])
+ orientation = Quaternion.from_rotation_matrix(obb.R)
+ pose = PoseStamped(
+ ts=det.ts,
+ frame_id=frame_id,
+ position=center,
+ orientation=orientation,
+ )
+
+ objects.append(
+ cls(
+ bbox=det.bbox,
+ track_id=det.track_id,
+ class_id=det.class_id,
+ confidence=det.confidence,
+ name=det.name,
+ ts=det.ts,
+ image=det.image,
+ frame_id=frame_id,
+ pointcloud=pc,
+ center=center,
+ size=size,
+ pose=pose,
+ camera_transform=camera_transform,
+ mask=store_mask,
+ )
+ )
+
+ return objects
+
+
+def aggregate_pointclouds(objects: list[Object]) -> PointCloud2:
+ """Aggregate all object pointclouds into a single colored pointcloud.
+
+ Each object's points are colored based on its track_id.
+
+ Args:
+ objects: List of Object instances with pointclouds
+
+ Returns:
+ Combined PointCloud2 with all points colored by object (empty if no points).
+ """
+ if not objects:
+ return PointCloud2(pointcloud=o3d.geometry.PointCloud(), frame_id="", ts=0.0)
+
+ all_points = []
+ all_colors = []
+
+ for _i, obj in enumerate(objects):
+ points, colors = obj.pointcloud.as_numpy()
+ if len(points) == 0:
+ continue
+
+ try:
+ seed = int(obj.object_id, 16)
+ except (ValueError, TypeError):
+ seed = abs(hash(obj.object_id))
+ np.random.seed(abs(seed) % (2**32 - 1))
+ track_color = np.random.randint(50, 255, 3) / 255.0
+
+ if colors is not None:
+ blended = np.clip(0.6 * colors + 0.4 * track_color, 0.0, 1.0)
+ else:
+ blended = np.tile(track_color, (len(points), 1))
+
+ all_points.append(points)
+ all_colors.append(blended)
+
+ if not all_points:
+ return PointCloud2(
+ pointcloud=o3d.geometry.PointCloud(), frame_id=objects[0].frame_id, ts=objects[0].ts
+ )
+
+ combined_points = np.vstack(all_points)
+ combined_colors = np.vstack(all_colors)
+
+ pc = PointCloud2.from_numpy(
+ combined_points,
+ frame_id=objects[0].frame_id,
+ timestamp=objects[0].ts,
+ )
+ pcd = pc.pointcloud
+ pcd.colors = o3d.utility.Vector3dVector(combined_colors)
+ pc.pointcloud = pcd
+
+ return pc
+
+
+def to_detection3d_array(objects: list[Object]) -> Detection3DArray:
+ """Convert a list of Objects to a ROS Detection3DArray message.
+
+ Args:
+ objects: List of Object instances
+
+ Returns:
+ Detection3DArray ROS message
+ """
+ array = Detection3DArray()
+
+ if objects:
+ array.header = Header(objects[0].ts, objects[0].frame_id)
+
+ for obj in objects:
+ array.detections.append(obj.to_detection3d_msg())
+
+ return array
diff --git a/dimos/perception/detection/type/detection3d/pointcloud.py b/dimos/perception/detection/type/detection3d/pointcloud.py
index fd924a6564..7edceb17a5 100644
--- a/dimos/perception/detection/type/detection3d/pointcloud.py
+++ b/dimos/perception/detection/type/detection3d/pointcloud.py
@@ -14,7 +14,7 @@
from __future__ import annotations
-from dataclasses import dataclass
+from dataclasses import dataclass, field
import functools
from typing import TYPE_CHECKING, Any
@@ -52,7 +52,7 @@
@dataclass
class Detection3DPC(Detection3D):
- pointcloud: PointCloud2
+ pointcloud: PointCloud2 = field(default_factory=PointCloud2)
@functools.cached_property
def center(self) -> Vector3:
diff --git a/dimos/perception/detection2d/utils.py b/dimos/perception/detection2d/utils.py
deleted file mode 100644
index a505eef7c8..0000000000
--- a/dimos/perception/detection2d/utils.py
+++ /dev/null
@@ -1,309 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections.abc import Sequence
-
-import cv2
-import numpy as np
-
-
-def filter_detections( # type: ignore[no-untyped-def]
- bboxes,
- track_ids,
- class_ids,
- confidences,
- names: Sequence[str],
- class_filter=None,
- name_filter=None,
- track_id_filter=None,
-):
- """
- Filter detection results based on class IDs, names, and/or tracking IDs.
-
- Args:
- bboxes: List of bounding boxes [x1, y1, x2, y2]
- track_ids: List of tracking IDs
- class_ids: List of class indices
- confidences: List of detection confidences
- names: List of class names
- class_filter: List/set of class IDs to keep, or None to keep all
- name_filter: List/set of class names to keep, or None to keep all
- track_id_filter: List/set of track IDs to keep, or None to keep all
-
- Returns:
- tuple: (filtered_bboxes, filtered_track_ids, filtered_class_ids,
- filtered_confidences, filtered_names)
- """
- # Convert filters to sets for efficient lookup
- if class_filter is not None:
- class_filter = set(class_filter)
- if name_filter is not None:
- name_filter = set(name_filter)
- if track_id_filter is not None:
- track_id_filter = set(track_id_filter)
-
- # Initialize lists for filtered results
- filtered_bboxes = []
- filtered_track_ids = []
- filtered_class_ids = []
- filtered_confidences = []
- filtered_names = []
-
- # Filter detections
- for bbox, track_id, class_id, conf, name in zip(
- bboxes, track_ids, class_ids, confidences, names, strict=False
- ):
- # Check if detection passes all specified filters
- keep = True
-
- if class_filter is not None:
- keep = keep and (class_id in class_filter)
-
- if name_filter is not None:
- keep = keep and (name in name_filter)
-
- if track_id_filter is not None:
- keep = keep and (track_id in track_id_filter)
-
- # If detection passes all filters, add it to results
- if keep:
- filtered_bboxes.append(bbox)
- filtered_track_ids.append(track_id)
- filtered_class_ids.append(class_id)
- filtered_confidences.append(conf)
- filtered_names.append(name)
-
- return (
- filtered_bboxes,
- filtered_track_ids,
- filtered_class_ids,
- filtered_confidences,
- filtered_names,
- )
-
-
-def extract_detection_results(result, class_filter=None, name_filter=None, track_id_filter=None): # type: ignore[no-untyped-def]
- """
- Extract and optionally filter detection information from a YOLO result object.
-
- Args:
- result: Ultralytics result object
- class_filter: List/set of class IDs to keep, or None to keep all
- name_filter: List/set of class names to keep, or None to keep all
- track_id_filter: List/set of track IDs to keep, or None to keep all
-
- Returns:
- tuple: (bboxes, track_ids, class_ids, confidences, names)
- - bboxes: list of [x1, y1, x2, y2] coordinates
- - track_ids: list of tracking IDs
- - class_ids: list of class indices
- - confidences: list of detection confidences
- - names: list of class names
- """
- bboxes = [] # type: ignore[var-annotated]
- track_ids = [] # type: ignore[var-annotated]
- class_ids = [] # type: ignore[var-annotated]
- confidences = [] # type: ignore[var-annotated]
- names = [] # type: ignore[var-annotated]
-
- if result.boxes is None:
- return bboxes, track_ids, class_ids, confidences, names
-
- for box in result.boxes:
- # Extract bounding box coordinates
- x1, y1, x2, y2 = box.xyxy[0].tolist()
-
- # Extract tracking ID if available
- track_id = -1
- if hasattr(box, "id") and box.id is not None:
- track_id = int(box.id[0].item())
-
- # Extract class information
- cls_idx = int(box.cls[0])
- name = result.names[cls_idx]
-
- # Extract confidence
- conf = float(box.conf[0])
-
- # Check filters before adding to results
- keep = True
- if class_filter is not None:
- keep = keep and (cls_idx in class_filter)
- if name_filter is not None:
- keep = keep and (name in name_filter)
- if track_id_filter is not None:
- keep = keep and (track_id in track_id_filter)
-
- if keep:
- bboxes.append([x1, y1, x2, y2])
- track_ids.append(track_id)
- class_ids.append(cls_idx)
- confidences.append(conf)
- names.append(name)
-
- return bboxes, track_ids, class_ids, confidences, names
-
-
-def plot_results( # type: ignore[no-untyped-def]
- image, bboxes, track_ids, class_ids, confidences, names: Sequence[str], alpha: float = 0.5
-):
- """
- Draw bounding boxes and labels on the image.
-
- Args:
- image: Original input image
- bboxes: List of bounding boxes [x1, y1, x2, y2]
- track_ids: List of tracking IDs
- class_ids: List of class indices
- confidences: List of detection confidences
- names: List of class names
- alpha: Transparency of the overlay
-
- Returns:
- Image with visualized detections
- """
- vis_img = image.copy()
-
- for bbox, track_id, conf, name in zip(bboxes, track_ids, confidences, names, strict=False):
- # Generate consistent color based on track_id or class name
- if track_id != -1:
- np.random.seed(track_id)
- else:
- np.random.seed(hash(name) % 100000)
- color = np.random.randint(0, 255, (3,), dtype=np.uint8)
- np.random.seed(None)
-
- # Draw bounding box
- x1, y1, x2, y2 = map(int, bbox)
- cv2.rectangle(vis_img, (x1, y1), (x2, y2), color.tolist(), 2)
-
- # Prepare label text
- if track_id != -1:
- label = f"ID:{track_id} {name} {conf:.2f}"
- else:
- label = f"{name} {conf:.2f}"
-
- # Calculate text size for background rectangle
- (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
-
- # Draw background rectangle for text
- cv2.rectangle(vis_img, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1)
-
- # Draw text with white color for better visibility
- cv2.putText(
- vis_img, label, (x1 + 2, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1
- )
-
- return vis_img
-
-
-def calculate_depth_from_bbox(depth_map, bbox): # type: ignore[no-untyped-def]
- """
- Calculate the average depth of an object within a bounding box.
- Uses the 25th to 75th percentile range to filter outliers.
-
- Args:
- depth_map: The depth map
- bbox: Bounding box in format [x1, y1, x2, y2]
-
- Returns:
- float: Average depth in meters, or None if depth estimation fails
- """
- try:
- # Extract region of interest from the depth map
- x1, y1, x2, y2 = map(int, bbox)
- roi_depth = depth_map[y1:y2, x1:x2]
-
- if roi_depth.size == 0:
- return None
-
- # Calculate 25th and 75th percentile to filter outliers
- p25 = np.percentile(roi_depth, 25)
- p75 = np.percentile(roi_depth, 75)
-
- # Filter depth values within this range
- filtered_depth = roi_depth[(roi_depth >= p25) & (roi_depth <= p75)]
-
- # Calculate average depth (convert to meters)
- if filtered_depth.size > 0:
- return np.mean(filtered_depth) / 1000.0 # Convert mm to meters
-
- return None
- except Exception as e:
- print(f"Error calculating depth from bbox: {e}")
- return None
-
-
-def calculate_distance_angle_from_bbox(bbox, depth: int, camera_intrinsics): # type: ignore[no-untyped-def]
- """
- Calculate distance and angle to object center based on bbox and depth.
-
- Args:
- bbox: Bounding box [x1, y1, x2, y2]
- depth: Depth value in meters
- camera_intrinsics: List [fx, fy, cx, cy] with camera parameters
-
- Returns:
- tuple: (distance, angle) in meters and radians
- """
- if camera_intrinsics is None:
- raise ValueError("Camera intrinsics required for distance calculation")
-
- # Extract camera parameters
- fx, _fy, cx, _cy = camera_intrinsics
-
- # Calculate center of bounding box in pixels
- x1, y1, x2, y2 = bbox
- center_x = (x1 + x2) / 2
- (y1 + y2) / 2
-
- # Calculate normalized image coordinates
- x_norm = (center_x - cx) / fx
-
- # Calculate angle (positive to the right)
- angle = np.arctan(x_norm)
-
- # Calculate distance using depth and angle
- distance = depth / np.cos(angle) if np.cos(angle) != 0 else depth
-
- return distance, angle
-
-
-def calculate_object_size_from_bbox(bbox, depth: int, camera_intrinsics): # type: ignore[no-untyped-def]
- """
- Estimate physical width and height of object in meters.
-
- Args:
- bbox: Bounding box [x1, y1, x2, y2]
- depth: Depth value in meters
- camera_intrinsics: List [fx, fy, cx, cy] with camera parameters
-
- Returns:
- tuple: (width, height) in meters
- """
- if camera_intrinsics is None:
- return 0.0, 0.0
-
- fx, fy, _, _ = camera_intrinsics
-
- # Calculate bbox dimensions in pixels
- x1, y1, x2, y2 = bbox
- width_px = x2 - x1
- height_px = y2 - y1
-
- # Convert to meters using similar triangles and depth
- width_m = (width_px * depth) / fx
- height_m = (height_px * depth) / fy
-
- return width_m, height_m
diff --git a/dimos/hardware/manipulators/base/tests/__init__.py b/dimos/perception/experimental/__init__.py
similarity index 87%
rename from dimos/hardware/manipulators/base/tests/__init__.py
rename to dimos/perception/experimental/__init__.py
index f863fa5120..39ef33521d 100644
--- a/dimos/hardware/manipulators/base/tests/__init__.py
+++ b/dimos/perception/experimental/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2025-2026 Dimensional Inc.
+# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Tests for manipulator base module."""
+"""Experimental perception modules."""
diff --git a/dimos/perception/experimental/temporal_memory/README.md b/dimos/perception/experimental/temporal_memory/README.md
new file mode 100644
index 0000000000..9ef5f6cb22
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/README.md
@@ -0,0 +1,32 @@
+Temporal memory runs "Temporal/Spatial RAG" on streamed videos building a continuous entity-based
+memory over time. It uses a VLM to extract evidence in sliding windows, tracks
+entities across windows, maintains a rolling summary, and stores relations in a graph network.
+
+Methodology
+1) Sample frames at a target FPS and analyze them in sliding windows.
+2) Extract dense evidence with a VLM (caption + entities + relations).
+3) Update rolling summary for global context.
+4) Persist per-window evidence + entity graph for query-time context.
+
+Setup
+- Put your OpenAI key in `.env`:
+ `OPENAI_API_KEY=...`
+- Install dimensional dependencies
+
+Quickstart
+To run: `dimos --replay run unitree-go2-temporal-memory`
+
+In another terminal: `humancli` to chat with the agent and run memory queries.
+
+Artifacts
+By default, artifacts are written under `assets/temporal_memory`:
+- `evidence.jsonl` (window evidence: captions, entities, relations)
+- `state.json` (rolling summary + roster state)
+- `entities.json` (current entity roster)
+- `frames_index.jsonl` (timestamps for saved frames; written on stop)
+- `entity_graph.db` (SQLite graph of relations/distances)
+
+Notes
+- Evidence is extracted in sliding windows, so queries can refer to recent or past entities.
+- Distance estimation can run in the background to enrich graph relations.
+- If you want a different output directory, set `TemporalMemoryConfig(output_dir=...)`.
diff --git a/dimos/perception/experimental/temporal_memory/__init__.py b/dimos/perception/experimental/temporal_memory/__init__.py
new file mode 100644
index 0000000000..3cc61601ce
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Temporal memory package."""
+
+from .temporal_memory import Frame, TemporalMemory, TemporalMemoryConfig, temporal_memory
+
+__all__ = [
+ "Frame",
+ "TemporalMemory",
+ "TemporalMemoryConfig",
+ "temporal_memory",
+]
diff --git a/dimos/perception/experimental/temporal_memory/clip_filter.py b/dimos/perception/experimental/temporal_memory/clip_filter.py
new file mode 100644
index 0000000000..8faac3fad8
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/clip_filter.py
@@ -0,0 +1,171 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CLIP-based frame filtering for selecting diverse frames from video windows."""
+
+from typing import Any, cast
+
+import numpy as np
+
+from dimos.msgs.sensor_msgs import Image
+from dimos.utils.logging_config import setup_logger
+
+logger = setup_logger()
+
+try:
+ import torch # type: ignore
+
+ from dimos.models.embedding.clip import CLIPModel # type: ignore
+
+ CLIP_AVAILABLE = True
+except ImportError as e:
+ CLIP_AVAILABLE = False
+ logger.info(f"CLIP unavailable ({e}), using simple frame sampling")
+
+
+def _get_image_data(image: Image) -> np.ndarray[Any, Any]:
+ """Extract numpy array from Image."""
+ if not hasattr(image, "data"):
+ raise AttributeError(f"Image missing .data attribute: {type(image)}")
+ return cast("np.ndarray[Any, Any]", image.data)
+
+
+if CLIP_AVAILABLE:
+
+ class CLIPFrameFilter:
+ """Filter video frames using CLIP embeddings for diversity."""
+
+ def __init__(self, model_name: str = "ViT-B/32", device: str | None = None):
+ if not CLIP_AVAILABLE:
+ raise ImportError("CLIP not available. Install transformers[torch].")
+
+ resolved_name = (
+ "openai/clip-vit-base-patch32" if model_name == "ViT-B/32" else model_name
+ )
+ if device is None:
+ self._model = CLIPModel(model_name=resolved_name)
+ else:
+ self._model = CLIPModel(model_name=resolved_name, device=device)
+ logger.info(f"Loading CLIP {resolved_name} on {self._model.device}")
+
+ def _encode_images(self, images: list[Image]) -> "torch.Tensor":
+ """Encode images using CLIP."""
+ embeddings = self._model.embed(*images)
+ if not isinstance(embeddings, list):
+ embeddings = [embeddings]
+ vectors = [e.to_torch(self._model.device) for e in embeddings]
+ return torch.stack(vectors)
+
+ def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]:
+ """Select diverse frames using greedy farthest-point sampling in CLIP space."""
+ if len(frames) <= max_frames:
+ return frames
+
+ embeddings = self._encode_images([f.image for f in frames])
+
+ # Greedy farthest-point sampling
+ selected_indices = [0] # Always include first frame
+ remaining_indices = list(range(1, len(frames)))
+
+ while len(selected_indices) < max_frames and remaining_indices:
+ # Compute similarities: (num_remaining, num_selected)
+ similarities = embeddings[remaining_indices] @ embeddings[selected_indices].T
+ # Find max similarity for each remaining frame
+ max_similarities = similarities.max(dim=1)[0]
+ # Select frame most different from all selected
+ best_idx = int(max_similarities.argmin().item())
+
+ selected_indices.append(remaining_indices[best_idx])
+ remaining_indices.pop(best_idx)
+
+ return [frames[i] for i in sorted(selected_indices)]
+
+ def close(self) -> None:
+ """Clean up CLIP model."""
+ if hasattr(self, "_model"):
+ self._model.stop()
+ del self._model
+
+
+def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list[Any]:
+ """Fallback frame selection: uniform sampling across window."""
+ if len(frames) <= max_frames:
+ return frames
+ indices = [int(i * len(frames) / max_frames) for i in range(max_frames)]
+ return [frames[i] for i in indices]
+
+
+def adaptive_keyframes(
+ frames: list[Any],
+ min_frames: int = 3,
+ max_frames: int = 5,
+ change_threshold: float = 15.0,
+) -> list[Any]:
+ """Select frames based on visual change, adaptive count."""
+ if len(frames) <= min_frames:
+ return frames
+
+ # Compute frame-to-frame differences
+ try:
+ diffs = [
+ np.abs(
+ _get_image_data(frames[i].image).astype(float)
+ - _get_image_data(frames[i - 1].image).astype(float)
+ ).mean()
+ for i in range(1, len(frames))
+ ]
+ except (AttributeError, ValueError) as e:
+ logger.warning(f"Failed to compute frame diffs: {e}. Falling back to uniform sampling.")
+ return select_diverse_frames_simple(frames, max_frames)
+
+ total_motion = sum(diffs)
+ n_frames = int(np.clip(total_motion / change_threshold, min_frames, max_frames))
+
+ # Always include first and last
+ keyframe_indices = {0, len(frames) - 1}
+
+ # Add peaks in diff signal
+ for i in range(1, len(diffs) - 1):
+ if (
+ diffs[i] > diffs[i - 1]
+ and diffs[i] > diffs[i + 1]
+ and diffs[i] > change_threshold * 0.5
+ ):
+ keyframe_indices.add(i + 1)
+
+ # Adjust count
+ if len(keyframe_indices) > n_frames:
+ # Keep first, last, and highest-diff peaks
+ middle = [i for i in keyframe_indices if i not in (0, len(frames) - 1)]
+ middle_by_diff = sorted(middle, key=lambda i: diffs[i - 1], reverse=True)
+ keyframe_indices = {0, len(frames) - 1, *middle_by_diff[: n_frames - 2]}
+ elif len(keyframe_indices) < n_frames:
+ # Fill uniformly from remaining
+ needed = n_frames - len(keyframe_indices)
+ candidates = sorted(set(range(len(frames))) - keyframe_indices)
+ if candidates:
+ step = max(1, len(candidates) // (needed + 1))
+ keyframe_indices.update(candidates[::step][:needed])
+
+ return [frames[i] for i in sorted(keyframe_indices)]
+
+
+__all__ = [
+ "CLIP_AVAILABLE",
+ "adaptive_keyframes",
+ "select_diverse_frames_simple",
+]
+
+if CLIP_AVAILABLE:
+ __all__.append("CLIPFrameFilter")
diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py
new file mode 100644
index 0000000000..7109459f40
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py
@@ -0,0 +1,1018 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Entity Graph Database for storing and querying entity relationships.
+
+Maintains three types of graphs:
+1. Relations Graph: Interactions between entities (holds, looks_at, talks_to, etc.)
+2. Distance Graph: Spatial distances between entities
+3. Semantic Graph: Conceptual relationships (goes_with, part_of, used_for, etc.)
+
+All graphs share the same entity nodes but have different edge types.
+"""
+
+import json
+from pathlib import Path
+import sqlite3
+import threading
+from typing import TYPE_CHECKING, Any
+
+from dimos.utils.logging_config import setup_logger
+
+if TYPE_CHECKING:
+ from dimos.models.vl.base import VlModel
+ from dimos.msgs.sensor_msgs import Image
+
+logger = setup_logger()
+
+
+class EntityGraphDB:
+ """
+ SQLite-based graph database for entity relationships.
+
+ Thread-safe implementation using connection-per-thread pattern.
+ All graphs share the same entity nodes but maintain separate edge tables.
+ """
+
+    def __init__(self, db_path: str | Path) -> None:
+        """
+        Initialize the entity graph database.
+
+        Creates the parent directory for the database file if needed and
+        eagerly creates all tables/indexes via ``_init_schema``.
+
+        Args:
+            db_path: Path to the SQLite database file
+        """
+        self.db_path = Path(db_path)
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Thread-local storage for connections: sqlite3 connections are not
+        # safe to share across threads, so each thread lazily opens its own
+        # (see _get_connection).
+        self._local = threading.local()
+
+        # Initialize schema
+        self._init_schema()
+
+        logger.info(f"EntityGraphDB initialized at {self.db_path}")
+
+    def _get_connection(self) -> sqlite3.Connection:
+        """Get thread-local database connection.
+
+        Lazily opens one connection per calling thread and caches it on
+        ``self._local``. ``row_factory`` is set to ``sqlite3.Row`` so query
+        results can be accessed by column name throughout this class.
+
+        NOTE(review): ``PRAGMA foreign_keys`` is never enabled, so the
+        FOREIGN KEY clauses declared in ``_init_schema`` are not enforced by
+        SQLite — confirm this is intentional (e.g. relations may reference
+        entity ids that were never upserted).
+        """
+        if not hasattr(self._local, "conn"):
+            self._local.conn = sqlite3.connect(str(self.db_path))
+            self._local.conn.row_factory = sqlite3.Row
+        return self._local.conn  # type: ignore
+
+    def _init_schema(self) -> None:
+        """Create all tables and indexes if they do not already exist.
+
+        Four tables are created: ``entities`` (the shared node set) plus one
+        edge table per graph — ``relations`` (interactions), ``distances``
+        (spatial measurements), and ``semantic_relations`` (conceptual
+        knowledge). Idempotent: safe to call on an existing database file.
+        """
+        conn = self._get_connection()
+        cursor = conn.cursor()
+
+        # Entities table (shared nodes). ``metadata`` holds an optional JSON
+        # blob serialized/deserialized by the accessor methods.
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS entities (
+                entity_id TEXT PRIMARY KEY,
+                entity_type TEXT NOT NULL,
+                descriptor TEXT,
+                first_seen_ts REAL NOT NULL,
+                last_seen_ts REAL NOT NULL,
+                metadata TEXT
+            )
+        """)
+        # Time-based indexes back get_entities_by_time / ORDER BY last_seen_ts.
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_entities_first_seen ON entities(first_seen_ts)"
+        )
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_entities_last_seen ON entities(last_seen_ts)"
+        )
+
+        # Relations table (Graph 1: Interactions). ``evidence`` is a JSON
+        # array of strings; directed edge: subject --relation_type--> object.
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS relations (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                relation_type TEXT NOT NULL,
+                subject_id TEXT NOT NULL,
+                object_id TEXT NOT NULL,
+                confidence REAL DEFAULT 1.0,
+                timestamp_s REAL NOT NULL,
+                evidence TEXT,
+                notes TEXT,
+                FOREIGN KEY (subject_id) REFERENCES entities(entity_id),
+                FOREIGN KEY (object_id) REFERENCES entities(entity_id)
+            )
+        """)
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_subject ON relations(subject_id)")
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_object ON relations(object_id)")
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_type ON relations(relation_type)")
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_time ON relations(timestamp_s)")
+
+        # Distances table (Graph 2: Spatial). Pairs are stored with ids in
+        # alphabetical order (normalized by add_distance) so the edge is
+        # effectively undirected.
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS distances (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                entity_a_id TEXT NOT NULL,
+                entity_b_id TEXT NOT NULL,
+                distance_meters REAL,
+                distance_category TEXT,
+                confidence REAL DEFAULT 1.0,
+                timestamp_s REAL NOT NULL,
+                method TEXT,
+                FOREIGN KEY (entity_a_id) REFERENCES entities(entity_id),
+                FOREIGN KEY (entity_b_id) REFERENCES entities(entity_id)
+            )
+        """)
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_distances_pair ON distances(entity_a_id, entity_b_id)"
+        )
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_distances_time ON distances(timestamp_s)")
+
+        # Semantic relations table (Graph 3: Knowledge). One row per
+        # (type, a, b) triple; repeated observations update the row in place
+        # (see add_semantic_relation) rather than appending history.
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS semantic_relations (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                relation_type TEXT NOT NULL,
+                entity_a_id TEXT NOT NULL,
+                entity_b_id TEXT NOT NULL,
+                confidence REAL DEFAULT 1.0,
+                learned_from TEXT,
+                first_observed_ts REAL NOT NULL,
+                last_observed_ts REAL NOT NULL,
+                observation_count INTEGER DEFAULT 1,
+                FOREIGN KEY (entity_a_id) REFERENCES entities(entity_id),
+                FOREIGN KEY (entity_b_id) REFERENCES entities(entity_id)
+            )
+        """)
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_semantic_pair ON semantic_relations(entity_a_id, entity_b_id)"
+        )
+        cursor.execute(
+            "CREATE INDEX IF NOT EXISTS idx_semantic_type ON semantic_relations(relation_type)"
+        )
+
+        conn.commit()
+
+ def upsert_entity(
+ self,
+ entity_id: str,
+ entity_type: str,
+ descriptor: str,
+ timestamp_s: float,
+ metadata: dict[str, Any] | None = None,
+ ) -> None:
+ """
+ Insert or update an entity.
+
+ Args:
+ entity_id: Unique entity identifier (e.g., "E1")
+ entity_type: Type of entity (person, object, location, etc.)
+ descriptor: Text description of the entity
+ timestamp_s: Timestamp when entity was observed
+ metadata: Optional additional metadata
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ metadata_json = json.dumps(metadata) if metadata else None
+
+ cursor.execute(
+ """
+ INSERT INTO entities (entity_id, entity_type, descriptor, first_seen_ts, last_seen_ts, metadata)
+ VALUES (?, ?, ?, ?, ?, ?)
+ ON CONFLICT(entity_id) DO UPDATE SET
+ last_seen_ts = ?,
+ descriptor = COALESCE(excluded.descriptor, descriptor),
+ metadata = COALESCE(excluded.metadata, metadata)
+ """,
+ (
+ entity_id,
+ entity_type,
+ descriptor,
+ timestamp_s,
+ timestamp_s,
+ metadata_json,
+ timestamp_s,
+ ),
+ )
+
+ conn.commit()
+ logger.debug(f"Upserted entity {entity_id} (type={entity_type})")
+
+ def get_entity(self, entity_id: str) -> dict[str, Any] | None:
+ """
+ Get an entity by ID.
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute("SELECT * FROM entities WHERE entity_id = ?", (entity_id,))
+ row = cursor.fetchone()
+
+ if row is None:
+ return None
+
+ return {
+ "entity_id": row["entity_id"],
+ "entity_type": row["entity_type"],
+ "descriptor": row["descriptor"],
+ "first_seen_ts": row["first_seen_ts"],
+ "last_seen_ts": row["last_seen_ts"],
+ "metadata": json.loads(row["metadata"]) if row["metadata"] else None,
+ }
+
+ def get_all_entities(self, entity_type: str | None = None) -> list[dict[str, Any]]:
+ """Get all entities, optionally filtered by type."""
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ if entity_type:
+ cursor.execute(
+ "SELECT * FROM entities WHERE entity_type = ? ORDER BY last_seen_ts DESC",
+ (entity_type,),
+ )
+ else:
+ cursor.execute("SELECT * FROM entities ORDER BY last_seen_ts DESC")
+
+ rows = cursor.fetchall()
+ return [
+ {
+ "entity_id": row["entity_id"],
+ "entity_type": row["entity_type"],
+ "descriptor": row["descriptor"],
+ "first_seen_ts": row["first_seen_ts"],
+ "last_seen_ts": row["last_seen_ts"],
+ "metadata": json.loads(row["metadata"]) if row["metadata"] else None,
+ }
+ for row in rows
+ ]
+
+    def get_entities_by_time(
+        self,
+        time_window: tuple[float, float],
+        first_seen: bool = True,
+    ) -> list[dict[str, Any]]:
+        """Get entities first/last seen within a time window.
+
+        Args:
+            time_window: (start_ts, end_ts) tuple in seconds (inclusive
+                bounds, per SQL BETWEEN)
+            first_seen: If True, filter by first_seen_ts. If False, filter by last_seen_ts.
+
+        Returns:
+            List of entities seen within the time window
+        """
+        conn = self._get_connection()
+        cursor = conn.cursor()
+
+        # f-string interpolation is safe here: ts_field is one of two
+        # hard-coded column names, never caller-supplied text.
+        ts_field = "first_seen_ts" if first_seen else "last_seen_ts"
+        cursor.execute(
+            f"SELECT * FROM entities WHERE {ts_field} BETWEEN ? AND ? ORDER BY {ts_field} DESC",
+            time_window,
+        )
+
+        rows = cursor.fetchall()
+        # ``metadata`` is stored as a JSON string; deserialize per row.
+        return [
+            {
+                "entity_id": row["entity_id"],
+                "entity_type": row["entity_type"],
+                "descriptor": row["descriptor"],
+                "first_seen_ts": row["first_seen_ts"],
+                "last_seen_ts": row["last_seen_ts"],
+                "metadata": json.loads(row["metadata"]) if row["metadata"] else None,
+            }
+            for row in rows
+        ]
+
+ def add_relation(
+ self,
+ relation_type: str,
+ subject_id: str,
+ object_id: str,
+ confidence: float,
+ timestamp_s: float,
+ evidence: list[str] | None = None,
+ notes: str | None = None,
+ ) -> None:
+ """
+ Add a relation between two entities.
+
+ Args:
+ relation_type: Type of relation (holds, looks_at, talks_to, etc.)
+ subject_id: Subject entity ID
+ object_id: Object entity ID
+ confidence: Confidence score (0.0 to 1.0)
+ timestamp_s: Timestamp when relation was observed
+ evidence: Optional list of evidence strings
+ notes: Optional notes
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ evidence_json = json.dumps(evidence) if evidence else None
+
+ cursor.execute(
+ """
+ INSERT INTO relations (relation_type, subject_id, object_id, confidence, timestamp_s, evidence, notes)
+ VALUES (?, ?, ?, ?, ?, ?, ?)
+ """,
+ (relation_type, subject_id, object_id, confidence, timestamp_s, evidence_json, notes),
+ )
+
+ conn.commit()
+ logger.debug(f"Added relation: {subject_id} --{relation_type}--> {object_id}")
+
+    def get_relations_for_entity(
+        self,
+        entity_id: str,
+        relation_type: str | None = None,
+        time_window: tuple[float, float] | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Get all relations involving an entity (as subject OR object).
+
+        Args:
+            entity_id: Entity ID
+            relation_type: Optional filter by relation type
+            time_window: Optional (start_ts, end_ts) tuple (inclusive)
+
+        Returns:
+            List of relation dicts, most recent first
+        """
+        conn = self._get_connection()
+        cursor = conn.cursor()
+
+        # Build the query incrementally; every value goes through a bound
+        # parameter, so the dynamic SQL stays injection-safe.
+        query = """
+            SELECT * FROM relations
+            WHERE (subject_id = ? OR object_id = ?)
+        """
+        params: list[Any] = [entity_id, entity_id]
+
+        if relation_type:
+            query += " AND relation_type = ?"
+            params.append(relation_type)
+
+        if time_window:
+            query += " AND timestamp_s BETWEEN ? AND ?"
+            params.extend(time_window)
+
+        query += " ORDER BY timestamp_s DESC"
+
+        cursor.execute(query, params)
+        rows = cursor.fetchall()
+
+        # ``evidence`` is stored as a JSON array string; deserialize per row.
+        return [
+            {
+                "id": row["id"],
+                "relation_type": row["relation_type"],
+                "subject_id": row["subject_id"],
+                "object_id": row["object_id"],
+                "confidence": row["confidence"],
+                "timestamp_s": row["timestamp_s"],
+                "evidence": json.loads(row["evidence"]) if row["evidence"] else None,
+                "notes": row["notes"],
+            }
+            for row in rows
+        ]
+
+ def get_recent_relations(self, limit: int = 50) -> list[dict[str, Any]]:
+ """Get most recent relations."""
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ SELECT * FROM relations
+ ORDER BY timestamp_s DESC
+ LIMIT ?
+ """,
+ (limit,),
+ )
+
+ rows = cursor.fetchall()
+ return [
+ {
+ "id": row["id"],
+ "relation_type": row["relation_type"],
+ "subject_id": row["subject_id"],
+ "object_id": row["object_id"],
+ "confidence": row["confidence"],
+ "timestamp_s": row["timestamp_s"],
+ "evidence": json.loads(row["evidence"]) if row["evidence"] else None,
+ "notes": row["notes"],
+ }
+ for row in rows
+ ]
+
+    # ==================== Distance Operations (Graph 2) ====================
+
+    def add_distance(
+        self,
+        entity_a_id: str,
+        entity_b_id: str,
+        distance_meters: float | None,
+        distance_category: str | None,
+        confidence: float,
+        timestamp_s: float,
+        method: str,
+    ) -> None:
+        """
+        Add distance measurement between two entities.
+
+        Measurements are append-only; ``get_distance`` returns the latest one.
+
+        Args:
+            entity_a_id: First entity ID
+            entity_b_id: Second entity ID
+            distance_meters: Distance in meters (can be None if only categorical)
+            distance_category: Category (near/medium/far)
+            confidence: Confidence score
+            timestamp_s: Timestamp of measurement
+            method: Method used (vlm, depth_estimation, bbox)
+        """
+        conn = self._get_connection()
+        cursor = conn.cursor()
+
+        # Normalize order to avoid duplicates (store alphabetically), making
+        # the pair effectively undirected for lookups.
+        if entity_a_id > entity_b_id:
+            entity_a_id, entity_b_id = entity_b_id, entity_a_id
+
+        cursor.execute(
+            """
+            INSERT INTO distances (entity_a_id, entity_b_id, distance_meters, distance_category,
+                                   confidence, timestamp_s, method)
+            VALUES (?, ?, ?, ?, ?, ?, ?)
+            """,
+            (
+                entity_a_id,
+                entity_b_id,
+                distance_meters,
+                distance_category,
+                confidence,
+                timestamp_s,
+                method,
+            ),
+        )
+
+        conn.commit()
+        logger.debug(
+            f"Added distance: {entity_a_id} <--> {entity_b_id}: {distance_meters}m ({distance_category})"
+        )
+
+ def get_distance(
+ self,
+ entity_a_id: str,
+ entity_b_id: str,
+ ) -> dict[str, Any] | None:
+ """Get most recent distance between two entities.
+
+ Args:
+ entity_a_id: First entity ID
+ entity_b_id: Second entity ID
+
+ Returns:
+ Distance dict or None
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ # Normalize order
+ if entity_a_id > entity_b_id:
+ entity_a_id, entity_b_id = entity_b_id, entity_a_id
+
+ cursor.execute(
+ """
+ SELECT * FROM distances
+ WHERE entity_a_id = ? AND entity_b_id = ?
+ ORDER BY timestamp_s DESC
+ LIMIT 1
+ """,
+ (entity_a_id, entity_b_id),
+ )
+
+ row = cursor.fetchone()
+ if row is None:
+ return None
+
+ return {
+ "entity_a_id": row["entity_a_id"],
+ "entity_b_id": row["entity_b_id"],
+ "distance_meters": row["distance_meters"],
+ "distance_category": row["distance_category"],
+ "confidence": row["confidence"],
+ "timestamp_s": row["timestamp_s"],
+ "method": row["method"],
+ }
+
+ def get_distance_history(
+ self,
+ entity_a_id: str,
+ entity_b_id: str,
+ ) -> list[dict[str, Any]]:
+ """Get all distance measurements between two entities.
+
+ Args:
+ entity_a_id: First entity ID
+ entity_b_id: Second entity ID
+
+ Returns:
+ List of distance dicts, most recent first
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ # Normalize order
+ if entity_a_id > entity_b_id:
+ entity_a_id, entity_b_id = entity_b_id, entity_a_id
+
+ cursor.execute(
+ """
+ SELECT * FROM distances
+ WHERE entity_a_id = ? AND entity_b_id = ?
+ ORDER BY timestamp_s DESC
+ """,
+ (entity_a_id, entity_b_id),
+ )
+
+ return [
+ {
+ "entity_a_id": row["entity_a_id"],
+ "entity_b_id": row["entity_b_id"],
+ "distance_meters": row["distance_meters"],
+ "distance_category": row["distance_category"],
+ "confidence": row["confidence"],
+ "timestamp_s": row["timestamp_s"],
+ "method": row["method"],
+ }
+ for row in cursor.fetchall()
+ ]
+
+    def get_nearby_entities(
+        self,
+        entity_id: str,
+        max_distance: float,
+        latest_only: bool = True,
+    ) -> list[dict[str, Any]]:
+        """
+        Find entities within a distance threshold.
+
+        Rows whose ``distance_meters`` is NULL (category-only measurements)
+        are excluded.
+
+        Args:
+            entity_id: Reference entity ID
+            max_distance: Maximum distance in meters (inclusive)
+            latest_only: If True, use only the latest measurement per pair
+
+        Returns:
+            List of nearby entities with distances, nearest first
+        """
+        conn = self._get_connection()
+        cursor = conn.cursor()
+
+        if latest_only:
+            # Get latest distance for each pair. The JOIN's CASE expression
+            # evaluates to NULL when the row involves neither side of
+            # entity_id, which drops the row (NULL is not true).
+            # NOTE(review): "latest" is selected via MAX(id) per pair, which
+            # assumes rowids increase in insertion order AND inserts arrive
+            # in timestamp order — confirm, since rows carry timestamp_s.
+            query = """
+                SELECT d.*, e.entity_type, e.descriptor
+                FROM distances d
+                INNER JOIN entities e ON (
+                    CASE
+                        WHEN d.entity_a_id = ? THEN e.entity_id = d.entity_b_id
+                        WHEN d.entity_b_id = ? THEN e.entity_id = d.entity_a_id
+                    END
+                )
+                WHERE (d.entity_a_id = ? OR d.entity_b_id = ?)
+                  AND d.distance_meters IS NOT NULL
+                  AND d.distance_meters <= ?
+                  AND d.id IN (
+                      SELECT MAX(id) FROM distances
+                      WHERE (entity_a_id = d.entity_a_id AND entity_b_id = d.entity_b_id)
+                      GROUP BY entity_a_id, entity_b_id
+                  )
+                ORDER BY d.distance_meters ASC
+            """
+            cursor.execute(query, (entity_id, entity_id, entity_id, entity_id, max_distance))
+        else:
+            # Same query without the per-pair MAX(id) restriction: every
+            # historical measurement under the threshold is returned.
+            query = """
+                SELECT d.*, e.entity_type, e.descriptor
+                FROM distances d
+                INNER JOIN entities e ON (
+                    CASE
+                        WHEN d.entity_a_id = ? THEN e.entity_id = d.entity_b_id
+                        WHEN d.entity_b_id = ? THEN e.entity_id = d.entity_a_id
+                    END
+                )
+                WHERE (d.entity_a_id = ? OR d.entity_b_id = ?)
+                  AND d.distance_meters IS NOT NULL
+                  AND d.distance_meters <= ?
+                ORDER BY d.distance_meters ASC
+            """
+            cursor.execute(query, (entity_id, entity_id, entity_id, entity_id, max_distance))
+
+        rows = cursor.fetchall()
+        # Report the "other" side of each pair relative to entity_id.
+        return [
+            {
+                "entity_id": row["entity_b_id"]
+                if row["entity_a_id"] == entity_id
+                else row["entity_a_id"],
+                "entity_type": row["entity_type"],
+                "descriptor": row["descriptor"],
+                "distance_meters": row["distance_meters"],
+                "distance_category": row["distance_category"],
+                "confidence": row["confidence"],
+                "timestamp_s": row["timestamp_s"],
+            }
+            for row in rows
+        ]
+
+    def add_semantic_relation(
+        self,
+        relation_type: str,
+        entity_a_id: str,
+        entity_b_id: str,
+        confidence: float,
+        learned_from: str,
+        timestamp_s: float,
+    ) -> None:
+        """
+        Add or update a semantic relation.
+
+        First observation inserts a row with the given ``confidence`` and
+        ``observation_count`` 1. Repeat observations of the same
+        (type, a, b) triple bump the stored confidence by 0.1 (capped at
+        1.0) and increment the count.
+
+        NOTE(review): on repeat observations the caller's ``confidence`` and
+        ``learned_from`` arguments are ignored — confirm this is intended.
+
+        Args:
+            relation_type: Relation type (goes_with, opposite_of, part_of, used_for)
+            entity_a_id: First entity ID
+            entity_b_id: Second entity ID
+            confidence: Confidence score (used only on first insert)
+            learned_from: Source (llm, knowledge_base, observation)
+            timestamp_s: Timestamp when learned
+        """
+        conn = self._get_connection()
+        cursor = conn.cursor()
+
+        # Normalize order for symmetric relations (same scheme as distances).
+        if entity_a_id > entity_b_id:
+            entity_a_id, entity_b_id = entity_b_id, entity_a_id
+
+        # Check if relation exists
+        cursor.execute(
+            """
+            SELECT id, observation_count, confidence FROM semantic_relations
+            WHERE relation_type = ? AND entity_a_id = ? AND entity_b_id = ?
+            """,
+            (relation_type, entity_a_id, entity_b_id),
+        )
+
+        existing = cursor.fetchone()
+
+        if existing:
+            # Update existing relation (increase confidence, increment count)
+            new_count = existing["observation_count"] + 1
+            new_confidence = min(
+                1.0, existing["confidence"] + 0.1
+            )  # Increase confidence with observations
+
+            cursor.execute(
+                """
+                UPDATE semantic_relations
+                SET last_observed_ts = ?,
+                    observation_count = ?,
+                    confidence = ?
+                WHERE id = ?
+                """,
+                (timestamp_s, new_count, new_confidence, existing["id"]),
+            )
+        else:
+            # Insert new relation
+            cursor.execute(
+                """
+                INSERT INTO semantic_relations
+                (relation_type, entity_a_id, entity_b_id, confidence, learned_from,
+                 first_observed_ts, last_observed_ts, observation_count)
+                VALUES (?, ?, ?, ?, ?, ?, ?, 1)
+                """,
+                (
+                    relation_type,
+                    entity_a_id,
+                    entity_b_id,
+                    confidence,
+                    learned_from,
+                    timestamp_s,
+                    timestamp_s,
+                ),
+            )
+
+        conn.commit()
+        logger.debug(f"Added semantic relation: {entity_a_id} --{relation_type}--> {entity_b_id}")
+
+ def get_semantic_relations(
+ self,
+ entity_id: str | None = None,
+ relation_type: str | None = None,
+ ) -> list[dict[str, Any]]:
+ """
+ Get semantic relations, optionally filtered.
+
+ Args:
+ entity_id: Optional filter by entity
+ relation_type: Optional filter by relation type
+
+ Returns:
+ List of semantic relation dicts
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ query = "SELECT * FROM semantic_relations WHERE 1=1"
+ params: list[Any] = []
+
+ if entity_id:
+ query += " AND (entity_a_id = ? OR entity_b_id = ?)"
+ params.extend([entity_id, entity_id])
+
+ if relation_type:
+ query += " AND relation_type = ?"
+ params.append(relation_type)
+
+ query += " ORDER BY confidence DESC, observation_count DESC"
+
+ cursor.execute(query, params)
+ rows = cursor.fetchall()
+
+ return [
+ {
+ "id": row["id"],
+ "relation_type": row["relation_type"],
+ "entity_a_id": row["entity_a_id"],
+ "entity_b_id": row["entity_b_id"],
+ "confidence": row["confidence"],
+ "learned_from": row["learned_from"],
+ "first_observed_ts": row["first_observed_ts"],
+ "last_observed_ts": row["last_observed_ts"],
+ "observation_count": row["observation_count"],
+ }
+ for row in rows
+ ]
+
+    # querying
+
+    def get_entity_neighborhood(
+        self,
+        entity_id: str,
+        max_hops: int = 2,
+        include_distances: bool = True,
+        include_semantics: bool = True,
+    ) -> dict[str, Any]:
+        """
+        Get entity neighborhood (BFS traversal over all three graphs).
+
+        Distance edges are only followed within a hard-coded 10.0 m radius.
+
+        NOTE(review): relation/semantic lists can contain duplicates — an
+        edge between two visited entities is collected once per endpoint.
+        Dedupe downstream if exact counts matter.
+
+        Args:
+            entity_id: Starting entity ID
+            max_hops: Maximum number of hops to traverse
+            include_distances: Include distance graph
+            include_semantics: Include semantic graph
+
+        Returns:
+            Dict with entities, relations, distances, and semantics
+        """
+        visited_entities = {entity_id}
+        current_level = {entity_id}
+        all_relations = []
+        all_distances = []
+        all_semantics = []
+
+        # Standard BFS: expand one hop per iteration, tracking visited ids.
+        for _ in range(max_hops):
+            next_level = set()
+
+            for ent_id in current_level:
+                # Get relations (entity appears as subject or object)
+                relations = self.get_relations_for_entity(ent_id)
+                all_relations.extend(relations)
+
+                for rel in relations:
+                    other_id = (
+                        rel["object_id"] if rel["subject_id"] == ent_id else rel["subject_id"]
+                    )
+                    if other_id not in visited_entities:
+                        next_level.add(other_id)
+                        visited_entities.add(other_id)
+
+                # Get distances (within 10.0 m only)
+                if include_distances:
+                    distances = self.get_nearby_entities(ent_id, max_distance=10.0)
+                    all_distances.extend(distances)
+                    for dist in distances:
+                        other_id = dist["entity_id"]
+                        if other_id not in visited_entities:
+                            next_level.add(other_id)
+                            visited_entities.add(other_id)
+
+                # Get semantic relations
+                if include_semantics:
+                    semantics = self.get_semantic_relations(entity_id=ent_id)
+                    all_semantics.extend(semantics)
+                    for sem in semantics:
+                        other_id = (
+                            sem["entity_b_id"]
+                            if sem["entity_a_id"] == ent_id
+                            else sem["entity_a_id"]
+                        )
+                        if other_id not in visited_entities:
+                            next_level.add(other_id)
+                            visited_entities.add(other_id)
+
+            current_level = next_level
+            if not current_level:
+                break
+
+        # Resolve ids to full entity records; drop ids missing from the DB
+        # (possible because foreign keys are not enforced).
+        entities = [self.get_entity(ent_id) for ent_id in visited_entities]
+        entities = [e for e in entities if e is not None]
+
+        return {
+            "center_entity": entity_id,
+            "entities": entities,
+            "relations": all_relations,
+            "distances": all_distances,
+            "semantic_relations": all_semantics,
+            "num_hops": max_hops,
+        }
+
+ def get_stats(self) -> dict[str, Any]:
+ """Get database statistics."""
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute("SELECT COUNT(*) as count FROM entities")
+ entity_count = cursor.fetchone()["count"]
+
+ cursor.execute("SELECT COUNT(*) as count FROM relations")
+ relation_count = cursor.fetchone()["count"]
+
+ cursor.execute("SELECT COUNT(*) as count FROM distances")
+ distance_count = cursor.fetchone()["count"]
+
+ cursor.execute("SELECT COUNT(*) as count FROM semantic_relations")
+ semantic_count = cursor.fetchone()["count"]
+
+ return {
+ "entities": entity_count,
+ "relations": relation_count,
+ "distances": distance_count,
+ "semantic_relations": semantic_count,
+ }
+
+ def get_summary(self, recent_relations_limit: int = 5) -> dict[str, Any]:
+ """Get stats, all entities, and recent relations."""
+ return {
+ "stats": self.get_stats(),
+ "entities": self.get_all_entities(),
+ "recent_relations": self.get_recent_relations(limit=recent_relations_limit),
+ }
+
+ def save_window_data(self, parsed: dict[str, Any], timestamp_s: float) -> None:
+ """Save parsed window data (entities and relations) to the graph database."""
+ try:
+ # Save new entities
+ for entity in parsed.get("new_entities", []):
+ self.upsert_entity(
+ entity_id=entity["id"],
+ entity_type=entity["type"],
+ descriptor=entity.get("descriptor", "unknown"),
+ timestamp_s=timestamp_s,
+ )
+
+ # Save existing entities (update last_seen)
+ for entity in parsed.get("entities_present", []):
+ if isinstance(entity, dict) and "id" in entity:
+ descriptor = entity.get("descriptor")
+ if descriptor:
+ self.upsert_entity(
+ entity_id=entity["id"],
+ entity_type=entity.get("type", "unknown"),
+ descriptor=descriptor,
+ timestamp_s=timestamp_s,
+ )
+ else:
+ existing = self.get_entity(entity["id"])
+ if existing:
+ self.upsert_entity(
+ entity_id=entity["id"],
+ entity_type=existing["entity_type"],
+ descriptor=existing["descriptor"],
+ timestamp_s=timestamp_s,
+ )
+
+ # Save relations
+ for relation in parsed.get("relations", []):
+ subject_id = (
+ relation["subject"].split("|")[0]
+ if "|" in relation["subject"]
+ else relation["subject"]
+ )
+ object_id = (
+ relation["object"].split("|")[0]
+ if "|" in relation["object"]
+ else relation["object"]
+ )
+
+ self.add_relation(
+ relation_type=relation["type"],
+ subject_id=subject_id,
+ object_id=object_id,
+ confidence=relation.get("confidence", 1.0),
+ timestamp_s=timestamp_s,
+ evidence=relation.get("evidence"),
+ notes=relation.get("notes"),
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to save window data to graph DB: {e}", exc_info=True)
+
+    def estimate_and_save_distances(
+        self,
+        parsed: dict[str, Any],
+        frame_image: "Image",
+        vlm: "VlModel",
+        timestamp_s: float,
+        max_distance_pairs: int = 5,
+    ) -> None:
+        """Estimate distances between entities using VLM and save to database.
+
+        Only pairs with no existing distance record are queried, and at most
+        ``max_distance_pairs`` pairs per call. VLM failures are logged and
+        swallowed (best-effort enrichment).
+
+        Args:
+            parsed: Parsed window data containing entities
+            frame_image: Frame image to analyze
+            vlm: VLM instance for distance estimation
+            timestamp_s: Timestamp for the distance measurements
+            max_distance_pairs: Maximum number of entity pairs to estimate
+        """
+        if not frame_image:
+            return
+
+        # Import here to avoid circular dependency
+        from . import temporal_utils as tu
+
+        # Collect entities with descriptors
+        # new_entities have descriptors from VLM
+        enriched_entities = []
+        for entity in parsed.get("new_entities", []):
+            if isinstance(entity, dict) and "id" in entity:
+                enriched_entities.append(
+                    {"id": entity["id"], "descriptor": entity.get("descriptor", "unknown")}
+                )
+
+        # entities_present only have IDs - need to fetch descriptors from DB
+        for entity in parsed.get("entities_present", []):
+            if isinstance(entity, dict) and "id" in entity:
+                entity_id = entity["id"]
+                # Fetch descriptor from DB; skip entities we have never seen.
+                db_entity = self.get_entity(entity_id)
+                if db_entity:
+                    enriched_entities.append(
+                        {"id": entity_id, "descriptor": db_entity.get("descriptor", "unknown")}
+                    )
+
+        # Need at least two entities to form a pair.
+        if len(enriched_entities) < 2:
+            return
+
+        # Generate unordered pairs that have no stored distance yet, capped
+        # at max_distance_pairs to bound the VLM prompt size.
+        pairs = [
+            (enriched_entities[i], enriched_entities[j])
+            for i in range(len(enriched_entities))
+            for j in range(i + 1, len(enriched_entities))
+            if not self.get_distance(enriched_entities[i]["id"], enriched_entities[j]["id"])
+        ][:max_distance_pairs]
+
+        if not pairs:
+            return
+
+        try:
+            # One batched VLM query for all pairs; only responses with a
+            # recognized category are persisted.
+            response = vlm.query(frame_image, tu.build_batch_distance_estimation_prompt(pairs))
+            for r in tu.parse_batch_distance_response(response, pairs):
+                if r["category"] in ("near", "medium", "far"):
+                    self.add_distance(
+                        entity_a_id=r["entity_a_id"],
+                        entity_b_id=r["entity_b_id"],
+                        distance_meters=r.get("distance_m"),
+                        distance_category=r["category"],
+                        confidence=r.get("confidence", 0.5),
+                        timestamp_s=timestamp_s,
+                        method="vlm",
+                    )
+        except Exception as e:
+            logger.warning(f"Failed to estimate distances: {e}", exc_info=True)
+
+    def commit(self) -> None:
+        """Commit all pending transactions and ensure data is flushed to disk.
+
+        Only the calling thread's connection is committed (connections are
+        thread-local). A full WAL checkpoint is then attempted so the main
+        database file is self-contained.
+        """
+        if hasattr(self._local, "conn"):
+            conn = self._local.conn
+            conn.commit()
+            # Force checkpoint to ensure WAL data is written to main database file
+            try:
+                conn.execute("PRAGMA wal_checkpoint(FULL)")
+            except Exception:
+                pass  # Ignore if WAL is not enabled
+
+ def close(self) -> None:
+ """Close database connection."""
+ if hasattr(self._local, "conn"):
+ self._local.conn.close()
+ del self._local.conn
diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.md b/dimos/perception/experimental/temporal_memory/temporal_memory.md
new file mode 100644
index 0000000000..0eaa3df893
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_memory.md
@@ -0,0 +1,39 @@
+Dimensional Temporal Memory is a lightweight Video RAG pipeline for building
+entity-centric memory over live or replayed video streams. It uses a VLM to
+extract evidence in sliding windows, tracks entities across time, maintains a
+rolling summary, and persists relations in a compact graph for query-time context.
+
+How It Works
+1) Sample frames at a target FPS and analyze them in sliding windows.
+2) Extract dense evidence with a VLM (caption + entities + relations).
+3) Update a rolling summary for global context.
+4) Persist per-window evidence and the entity graph for fast queries.
+
+Setup
+- Add your OpenAI key to `.env`:
+ `OPENAI_API_KEY=...`
+- Install dependencies (recommended set from repo install guide):
+ `uv sync --extra dev --extra cpu --extra sim --extra drone`
+
+`uv sync` installs the locked dependency set from `uv.lock` to match the repo's
+known-good environment. `uv pip install ...` behaves like pip (ad-hoc installs)
+and can drift from the lockfile.
+
+Quickstart
+- Run Temporal Memory on a replay:
+ `dimos --replay run unitree-go2-temporal-memory`
+- In another terminal, open a chat session:
+ `humancli`
+
+Artifacts
+By default, artifacts are written under `assets/temporal_memory`:
+- `evidence.jsonl` (window evidence: captions, entities, relations)
+- `state.json` (rolling summary + roster state)
+- `entities.json` (current entity roster)
+- `frames_index.jsonl` (timestamps for saved frames; written on stop)
+- `entity_graph.db` (SQLite graph of relations/distances)
+
+Notes
+- Evidence is extracted in sliding windows; queries can reference recent or past entities.
+- Distance estimation can run in the background to enrich graph relations.
+- Change the output location via `TemporalMemoryConfig(output_dir=...)`.
diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py
new file mode 100644
index 0000000000..29d4ecf3d9
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py
@@ -0,0 +1,665 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Temporal Memory module for creating entity-based temporal understanding of video streams.
+
+This module implements a sophisticated temporal memory system inspired by VideoRAG,
+using VLM (Vision-Language Model) API calls to maintain entity rosters, rolling summaries,
+and temporal relationships across video frames.
+"""
+
+from collections import deque
+from dataclasses import dataclass
+import json
+import os
+from pathlib import Path
+import threading
+import time
+from typing import Any
+
+from reactivex import Subject, interval
+from reactivex.disposable import Disposable
+
+from dimos.agents import skill
+from dimos.core import In, rpc
+from dimos.core.module import ModuleConfig
+from dimos.core.skill_module import SkillModule
+from dimos.models.vl.base import VlModel
+from dimos.msgs.sensor_msgs import Image
+from dimos.msgs.sensor_msgs.Image import sharpness_barrier
+
+from . import temporal_utils as tu
+from .clip_filter import (
+ CLIP_AVAILABLE,
+ adaptive_keyframes,
+)
+
+try:
+ from .clip_filter import CLIPFrameFilter
+except ImportError:
+ CLIPFrameFilter = type(None) # type: ignore[misc,assignment]
+from dimos.utils.logging_config import setup_logger
+
+from .entity_graph_db import EntityGraphDB
+
+logger = setup_logger()
+
+# Constants
+MAX_RECENT_WINDOWS = 50 # Max recent windows to keep in memory
+
+
@dataclass
class Frame:
    """A single captured video frame plus capture bookkeeping."""

    # Monotonically increasing capture index (assigned from TemporalMemory._frame_count).
    frame_index: int
    # Seconds relative to the session's video start anchor (not wall-clock epoch).
    timestamp_s: float
    # The captured image payload.
    image: Image
+
+
@dataclass
class TemporalMemoryConfig(ModuleConfig):
    """Tunable parameters for the TemporalMemory pipeline.

    Groups frame sampling, artifact output, VLM call, CLIP filtering, and
    graph-database knobs in one place so a single config object can be
    passed to the module or its deploy helper.
    """

    # Frame processing
    fps: float = 1.0  # Target frame sampling rate fed into the buffer
    window_s: float = 2.0  # Duration of each analysis window, in seconds
    stride_s: float = 2.0  # Seconds between successive window analyses
    summary_interval_s: float = 10.0  # How often the rolling summary is refreshed
    max_frames_per_window: int = 3  # Cap on frames sent to the VLM per window
    frame_buffer_size: int = 50  # Ring-buffer capacity for captured frames

    # Output — set to None to disable all on-disk artifacts (and the graph DB)
    output_dir: str | Path | None = "assets/temporal_memory"

    # VLM parameters
    max_tokens: int = 900
    temperature: float = 0.2

    # Frame filtering
    use_clip_filtering: bool = True  # Requires the optional CLIP dependency
    clip_model: str = "ViT-B/32"
    stale_scene_threshold: float = 5.0  # Windows below this change score are skipped

    # Graph database
    persistent_memory: bool = True  # Keep graph across sessions
    clear_memory_on_start: bool = False  # Wipe DB on startup
    enable_distance_estimation: bool = True  # Estimate entity distances
    max_distance_pairs: int = 5  # Max entity pairs per window

    # Graph context
    max_relations_per_entity: int = 10  # Max relations in query context
    nearby_distance_meters: float = 5.0  # "Nearby" threshold
+
+
class TemporalMemory(SkillModule):
    """Builds temporal understanding of video streams using VLMs.

    Processes frames reactively, maintains entity rosters, tracks temporal
    relationships, and builds rolling summaries. Responds to queries about
    the current state and recent events.
    """

    # Input stream of color frames from an upstream camera module.
    color_image: In[Image]

    def __init__(
        self, vlm: VlModel | None = None, config: TemporalMemoryConfig | None = None
    ) -> None:
        """Initialize locks, buffers, the optional CLIP filter, and output/DB paths.

        Args:
            vlm: Optional VLM instance. When None, one is created lazily from
                OPENAI_API_KEY on first use (see the ``vlm`` property).
            config: Optional configuration; defaults to ``TemporalMemoryConfig()``.
        """
        super().__init__()

        self._vlm = vlm  # Can be None for blueprint usage
        self.config: TemporalMemoryConfig = config or TemporalMemoryConfig()

        # single lock protects all state (threading.Lock — NOT reentrant)
        self._state_lock = threading.Lock()
        self._stopped = False

        # protected state
        self._state = tu.default_state()
        self._state["next_summary_at_s"] = float(self.config.summary_interval_s)
        self._frame_buffer: deque[Frame] = deque(maxlen=self.config.frame_buffer_size)
        self._recent_windows: deque[dict[str, Any]] = deque(maxlen=MAX_RECENT_WINDOWS)
        self._frame_count = 0
        # Start at -inf so first analysis passes stride_s check regardless of elapsed time
        self._last_analysis_time = -float("inf")
        self._video_start_wall_time: float | None = None

        # Track background distance estimation threads (joined in stop())
        self._distance_threads: list[threading.Thread] = []

        # clip filter - use instance state to avoid mutating shared config
        self._clip_filter: CLIPFrameFilter | None = None
        self._use_clip_filtering = self.config.use_clip_filtering
        if self._use_clip_filtering and CLIP_AVAILABLE:
            try:
                self._clip_filter = CLIPFrameFilter(model_name=self.config.clip_model)
                logger.info("clip filtering enabled")
            except Exception as e:
                # CLIP construction can fail (e.g. model download); degrade gracefully.
                logger.warning(f"clip init failed: {e}")
                self._use_clip_filtering = False
        elif self._use_clip_filtering:
            # Filtering was requested but the CLIP dependency is not importable.
            logger.warning("clip not available")
            self._use_clip_filtering = False

        # output directory — when unset, run memory-only with no graph DB
        self._graph_db: EntityGraphDB | None
        if self.config.output_dir:
            self._output_path = Path(self.config.output_dir)
            self._output_path.mkdir(parents=True, exist_ok=True)
            self._evidence_file = self._output_path / "evidence.jsonl"
            self._state_file = self._output_path / "state.json"
            self._entities_file = self._output_path / "entities.json"
            self._frames_index_file = self._output_path / "frames_index.jsonl"

            db_path = self._output_path / "entity_graph.db"
            # Remove any stale DB when persistence is off or a wipe was requested.
            if not self.config.persistent_memory or self.config.clear_memory_on_start:
                if db_path.exists():
                    db_path.unlink()
                    reason = (
                        "non-persistent mode"
                        if not self.config.persistent_memory
                        else "clear_memory_on_start=True"
                    )
                    logger.info(f"Deleted existing database: {reason}")

            self._graph_db = EntityGraphDB(db_path=db_path)

            logger.info(f"artifacts save to: {self._output_path}")
        else:
            self._graph_db = None

        logger.info(
            f"temporalmemory init: fps={self.config.fps}, "
            f"window={self.config.window_s}s, stride={self.config.stride_s}s"
        )
+
+ @property
+ def vlm(self) -> VlModel:
+ """Get or create VLM instance lazily."""
+ if self._vlm is None:
+ from dimos.models.vl.openai import OpenAIVlModel
+
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ raise ValueError(
+ "OPENAI_API_KEY environment variable not set. "
+ "Either set it or pass a vlm instance to TemporalMemory constructor."
+ )
+ self._vlm = OpenAIVlModel(api_key=api_key)
+ logger.info("Created OpenAIVlModel from OPENAI_API_KEY environment variable")
+ return self._vlm
+
    @rpc
    def start(self) -> None:
        """Begin consuming frames and schedule periodic window analysis."""
        super().start()

        with self._state_lock:
            self._stopped = False
            # Anchor relative timestamps to the first start; restarts keep the anchor.
            if self._video_start_wall_time is None:
                self._video_start_wall_time = time.time()

        def on_frame(image: Image) -> None:
            # Runs on the stream callback thread; all buffer mutation is under the lock.
            with self._state_lock:
                video_start = self._video_start_wall_time
                if video_start is None:
                    return  # Not started yet
                # Prefer the image's own capture timestamp when available.
                # NOTE(review): assumes image.ts is epoch seconds on the same clock
                # as time.time() — confirm against the Image message definition.
                if image.ts is not None:
                    timestamp_s = image.ts - video_start
                else:
                    timestamp_s = time.time() - video_start

                frame = Frame(
                    frame_index=self._frame_count,
                    timestamp_s=timestamp_s,
                    image=image,
                )
                self._frame_buffer.append(frame)
                self._frame_count += 1

        # sharpness_barrier rate-limits incoming frames to config.fps before buffering.
        frame_subject: Subject[Image] = Subject()
        self._disposables.add(
            frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(on_frame)
        )

        unsub_image = self.color_image.subscribe(frame_subject.on_next)
        self._disposables.add(Disposable(unsub_image))

        # Schedule window analysis every stride_s seconds
        self._disposables.add(
            interval(self.config.stride_s).subscribe(lambda _: self._analyze_window())
        )

        logger.info("temporalmemory started")
+
    @rpc
    def stop(self) -> None:
        """Flush artifacts, join background threads, close resources, and reset state.

        Order matters here: save state first (before it is cleared), then mark
        stopped, then join distance-estimation threads before closing the DB
        they write to, and only then clear buffers and stop transports.
        """
        # Save state before clearing (bypass _stopped check by saving directly)
        if self.config.output_dir:
            try:
                with self._state_lock:
                    state_copy = self._state.copy()
                    entity_roster = list(self._state.get("entity_roster", []))
                    with open(self._state_file, "w") as f:
                        json.dump(state_copy, f, indent=2, ensure_ascii=False)
                    logger.info(f"saved state to {self._state_file}")
                    with open(self._entities_file, "w") as f:
                        json.dump(entity_roster, f, indent=2, ensure_ascii=False)
                    logger.info(f"saved {len(entity_roster)} entities")
            except Exception as e:
                logger.error(f"save failed during stop: {e}", exc_info=True)

        self.save_frames_index()
        # From here on, periodic save paths become no-ops.
        with self._state_lock:
            self._stopped = True

        # Wait for background distance estimation threads to complete before closing DB
        if self._distance_threads:
            logger.info(f"Waiting for {len(self._distance_threads)} distance estimation threads...")
            for thread in self._distance_threads:
                thread.join(timeout=10.0)  # Wait max 10s per thread
            self._distance_threads.clear()

        if self._graph_db:
            db_path = self._graph_db.db_path
            self._graph_db.commit()  # save all pending transactions
            self._graph_db.close()
            self._graph_db = None

            # In non-persistent mode the DB file is session-scoped; delete it.
            if not self.config.persistent_memory and db_path.exists():
                db_path.unlink()
                logger.info("Deleted non-persistent database")

        if self._clip_filter:
            self._clip_filter.close()
            self._clip_filter = None

        with self._state_lock:
            self._frame_buffer.clear()
            self._recent_windows.clear()
            self._state = tu.default_state()

        super().stop()

        # Stop all stream transports to clean up LCM/shared memory threads
        # Note: We use public stream.transport API and rely on transport.stop() to clean up
        for stream in list(self.inputs.values()) + list(self.outputs.values()):
            if stream.transport is not None and hasattr(stream.transport, "stop"):
                try:
                    stream.transport.stop()
                except Exception as e:
                    logger.warning(f"Failed to stop stream transport: {e}")

        logger.info("temporalmemory stopped")
+
+ def _get_window_frames(self) -> tuple[list[Frame], dict[str, Any]] | None:
+ """Extract window frames from buffer with guards."""
+ with self._state_lock:
+ if not self._frame_buffer:
+ return None
+ current_time = self._frame_buffer[-1].timestamp_s
+ if current_time - self._last_analysis_time < self.config.stride_s:
+ return None
+ frames_needed = max(1, int(self.config.fps * self.config.window_s))
+ if len(self._frame_buffer) < frames_needed:
+ return None
+ window_frames = list(self._frame_buffer)[-frames_needed:]
+ state_snapshot = self._state.copy()
+ return window_frames, state_snapshot
+
+ def _query_vlm_for_window(
+ self,
+ window_frames: list[Frame],
+ state_snapshot: dict[str, Any],
+ w_start: float,
+ w_end: float,
+ ) -> str | None:
+ """Query VLM for window analysis."""
+ query = tu.build_window_prompt(
+ w_start=w_start, w_end=w_end, frame_count=len(window_frames), state=state_snapshot
+ )
+ try:
+ fmt = tu.get_structured_output_format()
+ if len(window_frames) > 1:
+ responses = self.vlm.query_batch(
+ [f.image for f in window_frames], query, response_format=fmt
+ )
+ return responses[0] if responses else ""
+ else:
+ return self.vlm.query(window_frames[0].image, query, response_format=fmt)
+ except Exception as e:
+ logger.error(f"vlm query failed [{w_start:.1f}-{w_end:.1f}s]: {e}", exc_info=True)
+ return None
+
+ def _save_window_artifacts(self, parsed: dict[str, Any], w_end: float) -> None:
+ """Save window data to graph DB and evidence file."""
+ if self._graph_db:
+ self._graph_db.save_window_data(parsed, w_end)
+ if self.config.output_dir:
+ self._append_evidence(parsed)
+
    def _analyze_window(self) -> None:
        """Analyze a temporal window of frames using VLM.

        Invoked on the reactivex interval every stride_s seconds. The slow VLM
        call happens outside the state lock; _last_analysis_time is advanced on
        every exit path so a failed window is not retried.
        """
        # Extract window frames with guards
        result = self._get_window_frames()
        if result is None:
            return
        window_frames, state_snapshot = result
        w_start, w_end = window_frames[0].timestamp_s, window_frames[-1].timestamp_s

        # Skip if scene hasn't changed
        if tu.is_scene_stale(window_frames, self.config.stale_scene_threshold):
            with self._state_lock:
                self._last_analysis_time = w_end
            return

        # Select diverse frames for analysis
        window_frames = (
            adaptive_keyframes(  # TODO: unclear if clip vs. diverse vs. this solution is best
                window_frames, max_frames=self.config.max_frames_per_window
            )
        )
        logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames")

        # Query VLM and parse response
        response_text = self._query_vlm_for_window(window_frames, state_snapshot, w_start, w_end)
        if response_text is None:
            # VLM call failed; advance the cursor so this window is not retried.
            with self._state_lock:
                self._last_analysis_time = w_end
            return

        parsed = tu.parse_window_response(response_text, w_start, w_end, len(window_frames))
        if "_error" in parsed:
            logger.error(f"parse error: {parsed['_error']}")
        # else:
        #     logger.info(f"parsed. caption: {parsed.get('caption', '')[:100]}")

        # Start distance estimation in background
        if self._graph_db and window_frames and self.config.enable_distance_estimation:
            mid_frame = window_frames[len(window_frames) // 2]
            if mid_frame.image:
                thread = threading.Thread(
                    target=self._graph_db.estimate_and_save_distances,
                    args=(parsed, mid_frame.image, self.vlm, w_end, self.config.max_distance_pairs),
                    daemon=True,
                )
                thread.start()
                # Prune finished threads before tracking the new one, so the
                # list joined in stop() stays bounded.
                self._distance_threads = [t for t in self._distance_threads if t.is_alive()]
                self._distance_threads.append(thread)

        # Update temporal state
        with self._state_lock:
            needs_summary = tu.update_state_from_window(
                self._state, parsed, w_end, self.config.summary_interval_s
            )
            self._recent_windows.append(parsed)
            self._last_analysis_time = w_end

        # Save artifacts
        self._save_window_artifacts(parsed, w_end)

        # Trigger summary update if needed
        if needs_summary:
            logger.info(f"updating summary at t≈{w_end:.1f}s")
            self._update_rolling_summary(w_end)

        # Periodic state saves
        with self._state_lock:
            window_count = len(self._recent_windows)
        # NOTE(review): save_state()/save_entities() re-acquire _state_lock
        # (a non-reentrant threading.Lock), so these calls must stay OUTSIDE
        # the with-block above or they would deadlock.
        if window_count % 10 == 0:
            self.save_state()
            self.save_entities()
+
    def _update_rolling_summary(self, w_end: float) -> None:
        """Refresh the rolling summary via the VLM and persist the result.

        Args:
            w_end: End timestamp (seconds) of the window that triggered the update.
        """
        with self._state_lock:
            if self._stopped:
                return
            rolling_summary = str(self._state.get("rolling_summary", ""))
            chunk_buffer = list(self._state.get("chunk_buffer", []))
            latest_frame = self._frame_buffer[-1].image if self._frame_buffer else None

        # Nothing to summarize without accumulated windows and a frame to send.
        if not chunk_buffer or not latest_frame:
            return

        prompt = tu.build_summary_prompt(
            rolling_summary=rolling_summary, chunk_windows=chunk_buffer
        )

        try:
            # Slow VLM call is made outside the lock.
            summary_text = self.vlm.query(latest_frame, prompt)
            if summary_text and summary_text.strip():
                with self._state_lock:
                    # Re-check: stop() may have run while the VLM call was in flight.
                    if self._stopped:
                        return
                    tu.apply_summary_update(
                        self._state, summary_text, w_end, self.config.summary_interval_s
                    )
                logger.info(f"updated summary: {summary_text[:100]}...")
                if self.config.output_dir and not self._stopped:
                    self.save_state()
                    self.save_entities()
        except Exception as e:
            logger.error(f"summary update failed: {e}", exc_info=True)
+
    @skill()
    def query(self, question: str) -> str:
        """Answer a question about the video stream using temporal memory and graph knowledge.

        This skill analyzes the current video stream and temporal memory state
        to answer questions about what is happening, what entities are present,
        recent events, spatial relationships, and conceptual knowledge.

        The system automatically accesses three knowledge graphs:
        - Interactions: relationships between entities (holds, looks_at, talks_to)
        - Spatial: distance and proximity information
        - Semantic: conceptual relationships (goes_with, used_for, etc.)

        Example:
            query("What entities are currently visible?")
            query("What did I do last week?")
            query("Where did I leave my keys?")
            query("What objects are near the person?")

        Args:
            question (str): The question to ask about the video stream.
                Examples: "What entities are visible?", "What happened recently?",
                "Is there a person in the scene?", "What am I holding?"

        Returns:
            str: Answer based on temporal memory, graph knowledge, and current frame.
        """
        # read state (snapshot everything under the lock, then release it)
        with self._state_lock:
            entity_roster = list(self._state.get("entity_roster", []))
            rolling_summary = str(self._state.get("rolling_summary", ""))
            last_present = list(self._state.get("last_present", []))
            recent_windows = list(self._recent_windows)
            if self._frame_buffer:
                latest_frame = self._frame_buffer[-1].image
                current_video_time_s = self._frame_buffer[-1].timestamp_s
            else:
                latest_frame = None
                current_video_time_s = 0.0

        if not latest_frame:
            return "no frames available"

        # build context from temporal state
        # Include entities from last_present and recent windows (both entities_present and new_entities)
        currently_present = {e["id"] for e in last_present if isinstance(e, dict) and "id" in e}
        # Only the last 3 windows are considered "current".
        for window in recent_windows[-3:]:
            # Add entities that were present
            for entity in window.get("entities_present", []):
                if isinstance(entity, dict) and isinstance(entity.get("id"), str):
                    currently_present.add(entity["id"])
            # Also include newly detected entities (they're present now)
            for entity in window.get("new_entities", []):
                if isinstance(entity, dict) and isinstance(entity.get("id"), str):
                    currently_present.add(entity["id"])

        context = {
            "entity_roster": entity_roster,
            "rolling_summary": rolling_summary,
            "currently_present_entities": sorted(currently_present),
            "recent_windows_count": len(recent_windows),
            "timestamp": time.time(),
        }

        # enhance context with graph database knowledge
        if self._graph_db:
            # Extract time window from question using VLM
            time_window_s = tu.extract_time_window(question, self.vlm, latest_frame)

            # Query graph for ALL entities in roster (not just currently present)
            # This allows queries about entities that disappeared or were seen in the past
            all_entity_ids = [e["id"] for e in entity_roster if isinstance(e, dict) and "id" in e]

            if all_entity_ids:
                graph_context = tu.build_graph_context(
                    graph_db=self._graph_db,
                    entity_ids=all_entity_ids,
                    time_window_s=time_window_s,
                    max_relations_per_entity=self.config.max_relations_per_entity,
                    nearby_distance_meters=self.config.nearby_distance_meters,
                    current_video_time_s=current_video_time_s,
                )
                context["graph_knowledge"] = graph_context

        # build query prompt using temporal utils
        prompt = tu.build_query_prompt(question=question, context=context)

        # query vlm (slow, outside lock)
        try:
            answer_text = self.vlm.query(latest_frame, prompt)
            return answer_text.strip()
        except Exception as e:
            logger.error(f"query failed: {e}", exc_info=True)
            return f"error: {e}"
+
+ @rpc
+ def clear_history(self) -> bool:
+ """Clear temporal memory state."""
+ try:
+ with self._state_lock:
+ self._state = tu.default_state()
+ self._state["next_summary_at_s"] = float(self.config.summary_interval_s)
+ self._recent_windows.clear()
+ logger.info("cleared history")
+ return True
+ except Exception as e:
+ logger.error(f"clear_history failed: {e}", exc_info=True)
+ return False
+
+ @rpc
+ def get_state(self) -> dict[str, Any]:
+ with self._state_lock:
+ return {
+ "entity_count": len(self._state.get("entity_roster", [])),
+ "entities": list(self._state.get("entity_roster", [])),
+ "rolling_summary": str(self._state.get("rolling_summary", "")),
+ "frame_count": self._frame_count,
+ "buffer_size": len(self._frame_buffer),
+ "recent_windows": len(self._recent_windows),
+ "currently_present": list(self._state.get("last_present", [])),
+ }
+
+ @rpc
+ def get_entity_roster(self) -> list[dict[str, Any]]:
+ with self._state_lock:
+ return list(self._state.get("entity_roster", []))
+
+ @rpc
+ def get_rolling_summary(self) -> str:
+ with self._state_lock:
+ return str(self._state.get("rolling_summary", ""))
+
+ @rpc
+ def get_graph_db_stats(self) -> dict[str, Any]:
+ """Get statistics and sample data from the graph database.
+
+ Returns empty structures when no database is available (no-error pattern).
+ """
+ if not self._graph_db:
+ return {"stats": {}, "entities": [], "recent_relations": []}
+ return self._graph_db.get_summary()
+
+ @rpc
+ def save_state(self) -> bool:
+ if not self.config.output_dir:
+ return False
+ try:
+ with self._state_lock:
+ # Don't save if stopped (state has been cleared)
+ if self._stopped:
+ return False
+ state_copy = self._state.copy()
+ with open(self._state_file, "w") as f:
+ json.dump(state_copy, f, indent=2, ensure_ascii=False)
+ logger.info(f"saved state to {self._state_file}")
+ return True
+ except Exception as e:
+ logger.error(f"save state failed: {e}", exc_info=True)
+ return False
+
+ def _append_evidence(self, evidence: dict[str, Any]) -> None:
+ try:
+ with open(self._evidence_file, "a") as f:
+ f.write(json.dumps(evidence, ensure_ascii=False) + "\n")
+ except Exception as e:
+ logger.error(f"append evidence failed: {e}", exc_info=True)
+
+ def save_entities(self) -> bool:
+ if not self.config.output_dir:
+ return False
+ try:
+ with self._state_lock:
+ # Don't save if stopped (state has been cleared)
+ if self._stopped:
+ return False
+ entity_roster = list(self._state.get("entity_roster", []))
+ with open(self._entities_file, "w") as f:
+ json.dump(entity_roster, f, indent=2, ensure_ascii=False)
+ logger.info(f"saved {len(entity_roster)} entities")
+ return True
+ except Exception as e:
+ logger.error(f"save entities failed: {e}", exc_info=True)
+ return False
+
+ def save_frames_index(self) -> bool:
+ if not self.config.output_dir:
+ return False
+ try:
+ with self._state_lock:
+ frames = list(self._frame_buffer)
+
+ frames_index = [
+ {
+ "frame_index": f.frame_index,
+ "timestamp_s": f.timestamp_s,
+ "timestamp": tu.format_timestamp(f.timestamp_s),
+ }
+ for f in frames
+ ]
+
+ if frames_index:
+ with open(self._frames_index_file, "w", encoding="utf-8") as f:
+ for rec in frames_index:
+ f.write(json.dumps(rec, ensure_ascii=False) + "\n")
+ logger.info(f"saved {len(frames_index)} frames")
+ return True
+ except Exception as e:
+ logger.error(f"save frames failed: {e}", exc_info=True)
+ return False
+
+
+temporal_memory = TemporalMemory.blueprint
+
+__all__ = ["Frame", "TemporalMemory", "TemporalMemoryConfig", "temporal_memory"]
diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py
new file mode 100644
index 0000000000..611385630e
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py
@@ -0,0 +1,60 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Deployment helpers for TemporalMemory module.
+"""
+
+import os
+
+from dimos import spec
+from dimos.core import DimosCluster
+from dimos.models.vl.base import VlModel
+
+from .temporal_memory import TemporalMemory, TemporalMemoryConfig
+
+
def deploy(
    dimos: DimosCluster,
    camera: spec.Camera,
    vlm: VlModel | None = None,
    config: TemporalMemoryConfig | None = None,
) -> TemporalMemory:
    """Deploy TemporalMemory with a camera.

    Args:
        dimos: Dimos cluster instance
        camera: Camera module to connect to
        vlm: Optional VLM instance (creates OpenAI VLM if None)
        config: Optional temporal memory configuration

    Returns:
        The started TemporalMemory module, already subscribed to the camera.

    Raises:
        ValueError: If vlm is None and OPENAI_API_KEY is not set.
    """
    if vlm is None:
        from dimos.models.vl.openai import OpenAIVlModel

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")
        vlm = OpenAIVlModel(api_key=api_key)

    temporal_memory = dimos.deploy(TemporalMemory, vlm=vlm, config=config)  # type: ignore[attr-defined]

    # Ensure the camera output has a transport before connecting; fall back to
    # a JPEG shared-memory transport when none is configured.
    if camera.color_image.transport is None:
        from dimos.core.transport import JpegShmTransport

        transport = JpegShmTransport("/temporal_memory/color_image")
        camera.color_image.transport = transport

    temporal_memory.color_image.connect(camera.color_image)
    temporal_memory.start()
    return temporal_memory  # type: ignore[return-value,no-any-return]
diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory_example.py b/dimos/perception/experimental/temporal_memory/temporal_memory_example.py
new file mode 100644
index 0000000000..8ba28bb174
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_memory_example.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example usage of TemporalMemory module with a VLM.
+
+This example demonstrates how to:
+1. Deploy a camera module
+2. Deploy TemporalMemory with the camera
+3. Query the temporal memory about entities and events
+"""
+
+from pathlib import Path
+
+from dotenv import load_dotenv
+
+from dimos import core
+from dimos.hardware.sensors.camera.module import CameraModule
+from dimos.hardware.sensors.camera.webcam import Webcam
+
+from .temporal_memory import TemporalMemoryConfig
+from .temporal_memory_deploy import deploy
+
+# Load environment variables
+load_dotenv()
+
+
def example_usage() -> None:
    """Example of how to use TemporalMemory.

    Deploys a webcam-backed camera and TemporalMemory, lets the pipeline build
    context for ~20 seconds, runs a few queries, prints state and graph stats,
    then shuts everything down exactly once.

    Fix over the original: the modules were stopped at the end of the try
    block AND again in the finally block, calling stop() twice on each module.
    References are now cleared after the explicit stop so finally only stops
    what is still running.
    """
    import time

    # Initialize variables to None for cleanup
    temporal_memory = None
    camera = None
    dimos = None

    try:
        # Create Dimos cluster
        dimos = core.start(1)
        # Deploy camera module
        camera = dimos.deploy(CameraModule, hardware=lambda: Webcam(camera_index=0))  # type: ignore[attr-defined]
        camera.start()

        # Deploy temporal memory using the deploy function
        output_dir = Path("./temporal_memory_output")
        temporal_memory = deploy(
            dimos,
            camera,
            vlm=None,  # Will auto-create OpenAIVlModel if None
            config=TemporalMemoryConfig(
                fps=1.0,  # Process 1 frame per second
                window_s=2.0,  # Analyze 2-second windows
                stride_s=2.0,  # New window every 2 seconds
                summary_interval_s=10.0,  # Update rolling summary every 10 seconds
                max_frames_per_window=3,  # Max 3 frames per window
                output_dir=output_dir,
            ),
        )

        print("TemporalMemory deployed and started!")
        print(f"Artifacts will be saved to: {output_dir}")

        # Let it run for a bit to build context
        print("Building temporal context... (wait ~15 seconds)")
        time.sleep(20)

        # Query the temporal memory
        questions = [
            "Are there any people in the scene?",
            "Describe the main activity happening now",
            "What has happened in the last few seconds?",
            "What entities are currently visible?",
        ]

        for question in questions:
            print(f"\nQuestion: {question}")
            answer = temporal_memory.query(question)
            print(f"Answer: {answer}")

        # Get current state
        state = temporal_memory.get_state()
        print("\n=== Current State ===")
        print(f"Entity count: {state['entity_count']}")
        print(f"Frame count: {state['frame_count']}")
        print(f"Rolling summary: {state['rolling_summary']}")
        print(f"Entities: {state['entities']}")

        # Get entity roster
        entities = temporal_memory.get_entity_roster()
        print("\n=== Entity Roster ===")
        for entity in entities:
            print(f"  {entity['id']}: {entity['descriptor']}")

        # Check graph database stats
        graph_stats = temporal_memory.get_graph_db_stats()
        print("\n=== Graph Database Stats ===")
        # Defensive only: get_graph_db_stats() documents an error-free return
        # shape ({"stats": ..., "entities": ..., "recent_relations": ...}),
        # so this branch is not expected to trigger.
        if "error" in graph_stats:
            print(f"Error: {graph_stats['error']}")
        else:
            print(f"Stats: {graph_stats['stats']}")
            print(f"\nEntities in DB ({len(graph_stats['entities'])}):")
            for entity in graph_stats["entities"]:
                print(f"  {entity['entity_id']} ({entity['entity_type']}): {entity['descriptor']}")
            print(f"\nRecent relations ({len(graph_stats['recent_relations'])}):")
            for rel in graph_stats["recent_relations"]:
                print(
                    f"  {rel['subject_id']} --{rel['relation_type']}--> {rel['object_id']} (confidence: {rel['confidence']:.2f})"
                )

        # Stop explicitly once done, then clear the references so the
        # finally-block does not stop the same modules a second time.
        temporal_memory.stop()
        temporal_memory = None
        camera.stop()
        camera = None
        print("\nTemporalMemory stopped")

    finally:
        # Only stop modules that were not already shut down above.
        if temporal_memory is not None:
            temporal_memory.stop()
        if camera is not None:
            camera.stop()
        if dimos is not None:
            dimos.close_all()  # type: ignore[attr-defined]


if __name__ == "__main__":
    example_usage()
diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py b/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py
new file mode 100644
index 0000000000..64950bee8a
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py
@@ -0,0 +1,60 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Temporal memory utilities for temporal memory. includes helper functions
+and prompts that are used to build the prompt for the VLM.
+"""
+
+# Re-export everything from submodules
+from .graph_utils import build_graph_context, extract_time_window
+from .helpers import clamp_text, format_timestamp, is_scene_stale, next_entity_id_hint
+from .parsers import parse_batch_distance_response, parse_window_response
+from .prompts import (
+ WINDOW_RESPONSE_SCHEMA,
+ build_batch_distance_estimation_prompt,
+ build_distance_estimation_prompt,
+ build_query_prompt,
+ build_summary_prompt,
+ build_window_prompt,
+ get_structured_output_format,
+)
+from .state import apply_summary_update, default_state, update_state_from_window
+
+__all__ = [
+ # Schema
+ "WINDOW_RESPONSE_SCHEMA",
+ # State management
+ "apply_summary_update",
+ # Prompts
+ "build_batch_distance_estimation_prompt",
+ "build_distance_estimation_prompt",
+ # Graph utils
+ "build_graph_context",
+ "build_query_prompt",
+ "build_summary_prompt",
+ "build_window_prompt",
+ # Helpers
+ "clamp_text",
+ "default_state",
+ "extract_time_window",
+ "format_timestamp",
+ "get_structured_output_format",
+ "is_scene_stale",
+ "next_entity_id_hint",
+ # Parsers
+ "parse_batch_distance_response",
+ "parse_window_response",
+ "update_state_from_window",
+]
diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py
new file mode 100644
index 0000000000..315d267a0c
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py
@@ -0,0 +1,206 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Graph database utility functions for temporal memory."""
+
+import re
+from typing import TYPE_CHECKING, Any
+
+from dimos.utils.logging_config import setup_logger
+
+if TYPE_CHECKING:
+ from dimos.models.vl.base import VlModel
+ from dimos.msgs.sensor_msgs import Image
+
+ from ..entity_graph_db import EntityGraphDB
+
+logger = setup_logger()
+
+
def extract_time_window(
    question: str,
    vlm: "VlModel",
    latest_frame: "Image | None" = None,
) -> float | None:
    """Extract time window from question using VLM with example-based learning.

    Uses a few example keywords as patterns, then asks VLM to extrapolate
    similar time references and return seconds.

    Args:
        question: User's question
        vlm: VLM instance to use for extraction
        latest_frame: Optional frame (required for VLM call, but image is ignored).
            If None, the VLM fallback is skipped and None is returned when the
            fast-path keywords do not match.

    Returns:
        Time window in seconds, or None if no time reference found
    """
    question_lower = question.lower()

    # Quick check for common patterns (fast path) — avoids a VLM round-trip
    # for the most frequent phrasings.
    if "last week" in question_lower or "past week" in question_lower:
        return 7 * 24 * 3600
    if "today" in question_lower or "last hour" in question_lower:
        return 3600
    if "recently" in question_lower or "recent" in question_lower:
        return 600

    # Use VLM to extract time reference from question
    # Provide examples and let VLM extrapolate similar patterns
    # Note: latest_frame is required by VLM interface but image content is ignored
    if not latest_frame:
        return None

    extraction_prompt = f"""Extract any time reference from this question and convert it to seconds.

Question: {question}

Examples of time references and their conversions:
- "last week" or "past week" -> 604800 seconds (7 days)
- "yesterday" -> 86400 seconds (1 day)
- "today" or "last hour" -> 3600 seconds (1 hour)
- "recently" or "recent" -> 600 seconds (10 minutes)
- "few minutes ago" -> 300 seconds (5 minutes)
- "just now" -> 60 seconds (1 minute)

Extrapolate similar patterns (e.g., "2 days ago", "this morning", "last month", etc.)
and convert to seconds. If no time reference is found, return "none".

Return ONLY a number (seconds) or the word "none". Do not include any explanation."""

    try:
        response = vlm.query(latest_frame, extraction_prompt)
        response = response.strip().lower()

        # An explicit "none" (or an empty reply) means no time reference found.
        # NOTE(review): "none" is matched as a substring, so a reply like
        # "nonessential" would also be treated as no-time-reference.
        if "none" in response or not response:
            return None

        # Take the first integer or decimal number that appears in the reply.
        numbers = re.findall(r"\d+(?:\.\d+)?", response)
        if numbers:
            seconds = float(numbers[0])
            # Sanity check: reasonable time windows (1 second to 1 year)
            if 1 <= seconds <= 365 * 24 * 3600:
                return seconds
    except Exception as e:
        # Best-effort extraction: on any VLM/parse failure, fall through to None.
        logger.debug(f"Time extraction failed: {e}")

    return None
+
+
def build_graph_context(
    graph_db: "EntityGraphDB",
    entity_ids: list[str],
    time_window_s: float | None = None,
    max_relations_per_entity: int = 10,
    nearby_distance_meters: float = 5.0,
    current_video_time_s: float | None = None,
) -> dict[str, Any]:
    """Build enriched context from graph database for given entities.

    Args:
        graph_db: Entity graph database instance
        entity_ids: List of entity IDs to get context for
        time_window_s: Optional time window in seconds (e.g., 3600 for last hour)
        max_relations_per_entity: Maximum relations to include per entity (default: 10)
        nearby_distance_meters: Distance threshold for "nearby" entities (default: 5.0)
        current_video_time_s: Current video timestamp in seconds (for time window queries).
            If None, uses latest entity timestamp from DB as reference.

    Returns:
        Dictionary with graph context including relationships, distances, and semantics.
        Returns an empty dict when inputs are missing or on any DB failure
        (best-effort enrichment; never raises).
    """
    if not graph_db or not entity_ids:
        return {}

    try:
        graph_context: dict[str, Any] = {
            "relationships": [],
            "spatial_info": [],
            "semantic_knowledge": [],
        }

        # Convert time_window_s to a (start_ts, end_ts) tuple if provided
        # Use video-relative timestamps, not wall-clock time
        time_window_tuple = None
        if time_window_s is not None:
            if current_video_time_s is not None:
                ref_time = current_video_time_s
            else:
                # Fallback: get the latest timestamp from entities in DB
                all_entities = graph_db.get_all_entities()
                ref_time = max((e.get("last_seen_ts", 0) for e in all_entities), default=0)
            time_window_tuple = (max(0, ref_time - time_window_s), ref_time)

        # Get recent relationships for each entity
        for entity_id in entity_ids:
            # Get relationships (Graph 1: interactions)
            relations = graph_db.get_relations_for_entity(
                entity_id=entity_id,
                relation_type=None,  # all types
                time_window=time_window_tuple,
            )
            # Keep only the last max_relations_per_entity relations.
            # NOTE(review): assumes get_relations_for_entity returns relations
            # in chronological order — confirm against EntityGraphDB.
            for rel in relations[-max_relations_per_entity:]:
                graph_context["relationships"].append(
                    {
                        "subject": rel["subject_id"],
                        "relation": rel["relation_type"],
                        "object": rel["object_id"],
                        "confidence": rel["confidence"],
                        "when": rel["timestamp_s"],
                    }
                )

            # Get spatial relationships (Graph 2: distances)
            nearby = graph_db.get_nearby_entities(
                entity_id=entity_id, max_distance=nearby_distance_meters, latest_only=True
            )
            for dist in nearby:
                graph_context["spatial_info"].append(
                    {
                        "entity_a": entity_id,
                        "entity_b": dist["entity_id"],
                        "distance": dist.get("distance_meters"),
                        "category": dist.get("distance_category"),
                        "confidence": dist["confidence"],
                    }
                )

            # Get semantic knowledge (Graph 3: conceptual relations)
            semantic_rels = graph_db.get_semantic_relations(
                entity_id=entity_id,
                relation_type=None,
            )
            for sem in semantic_rels:
                graph_context["semantic_knowledge"].append(
                    {
                        "entity_a": sem["entity_a_id"],
                        "relation": sem["relation_type"],
                        "entity_b": sem["entity_b_id"],
                        "confidence": sem["confidence"],
                        "observations": sem["observation_count"],
                    }
                )

        # Get graph statistics for context
        if entity_ids:
            stats = graph_db.get_stats()
            graph_context["total_entities"] = stats.get("entities", 0)
            graph_context["total_relations"] = stats.get("relations", 0)

        return graph_context

    except Exception as e:
        # Graph enrichment is optional context; degrade gracefully to empty.
        logger.warning(f"failed to build graph context: {e}")
        return {}
diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py b/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py
new file mode 100644
index 0000000000..513feb65a4
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py
@@ -0,0 +1,74 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper utility functions for temporal memory."""
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+if TYPE_CHECKING:
+ from ..temporal_memory import Frame
+
+
def next_entity_id_hint(roster: Any) -> str:
    """Return the next sequential entity ID (e.g. "E3") for the given roster.

    Non-list rosters and malformed entries are ignored; the result is one
    past the highest numeric "E<n>" ID seen, defaulting to "E1".
    """
    if not isinstance(roster, list):
        return "E1"
    highest = 0
    for entry in roster:
        if not isinstance(entry, dict):
            continue
        entity_id = entry.get("id")
        if isinstance(entity_id, str) and entity_id.startswith("E") and entity_id[1:].isdigit():
            highest = max(highest, int(entity_id[1:]))
    return f"E{highest + 1}"
+
+
def clamp_text(text: str, max_chars: int) -> str:
    """Truncate ``text`` to ``max_chars`` characters, appending "..." when cut."""
    return text if len(text) <= max_chars else text[:max_chars] + "..."
+
+
def format_timestamp(seconds: float) -> str:
    """Format seconds as an MM:SS.mmm timestamp string (e.g. 65.5 -> "01:05.500")."""
    minutes, remainder = divmod(seconds, 60)
    return f"{int(minutes):02d}:{remainder:06.3f}"
+
+
def is_scene_stale(frames: list["Frame"], stale_threshold: float = 5.0) -> bool:
    """Report whether the scene is unchanged between the first and last frame.

    Compares the mean absolute per-pixel difference of the two endpoint
    frames against ``stale_threshold``.

    Args:
        frames: List of frames to check
        stale_threshold: Threshold for mean pixel difference (default: 5.0)

    Returns:
        True if scene is stale (hasn't changed enough), False otherwise
        (including when fewer than two frames or image data is unavailable).
    """
    if len(frames) < 2:
        return False
    first, last = frames[0].image, frames[-1].image
    if first is None or last is None:
        return False
    # Images lacking raw pixel data cannot be compared.
    if not (hasattr(first, "data") and hasattr(last, "data")):
        return False
    mean_diff = np.abs(first.data.astype(float) - last.data.astype(float)).mean()
    return bool(mean_diff < stale_threshold)
diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/parsers.py b/dimos/perception/experimental/temporal_memory/temporal_utils/parsers.py
new file mode 100644
index 0000000000..a9b1a05d9f
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_utils/parsers.py
@@ -0,0 +1,156 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Response parsing functions for VLM outputs."""
+
+from typing import Any
+
+from dimos.utils.llm_utils import extract_json
+
+
def parse_batch_distance_response(
    response: str, entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]]
) -> list[dict[str, Any]]:
    """
    Parse batched distance estimation response.

    Args:
        response: VLM response text
        entity_pairs: Original entity pairs used in the prompt

    Returns:
        List of dicts with keys: entity_a_id, entity_b_id, category, distance_m, confidence

    Notes:
        Pair blocks whose number is out of range for ``entity_pairs`` (e.g. the
        VLM hallucinates "Pair 5" when only 3 pairs were given, or emits
        "Pair 0") are silently ignored instead of raising IndexError.
    """
    results: list[dict[str, Any]] = []

    # Parser state for the pair block currently being accumulated.
    current_pair_idx: int | None = None
    category: str | None = None
    distance_m: float | None = None
    confidence = 0.5

    def save_current() -> None:
        # Record the pair being parsed, but only if it is complete (has a
        # category) and its index refers to a real entry in entity_pairs.
        # The bounds check guards against hallucinated/invalid pair numbers.
        if (
            current_pair_idx is not None
            and category
            and 0 <= current_pair_idx < len(entity_pairs)
        ):
            entity_a, entity_b = entity_pairs[current_pair_idx]
            results.append(
                {
                    "entity_a_id": entity_a["id"],
                    "entity_b_id": entity_b["id"],
                    "category": category,
                    "distance_m": distance_m,
                    "confidence": confidence,
                }
            )

    for line in response.strip().split("\n"):
        line = line.strip()

        # Check for pair marker
        if line.startswith("Pair "):
            # Flush the previous pair, then reset state BEFORE parsing the
            # pair number so a malformed marker cannot double-save or merge
            # its fields into the previous pair.
            save_current()
            current_pair_idx = None
            category = None
            distance_m = None
            confidence = 0.5
            try:
                pair_num = int(line.split()[1].rstrip(":"))
            except (IndexError, ValueError):
                continue
            current_pair_idx = pair_num - 1  # Convert to 0-indexed

        # Parse distance fields
        elif line.startswith("category:"):
            category = line.split(":", 1)[1].strip().lower()
        elif line.startswith("distance_m:"):
            try:
                distance_m = float(line.split(":", 1)[1].strip())
            except (ValueError, IndexError):
                pass
        elif line.startswith("confidence:"):
            try:
                confidence = float(line.split(":", 1)[1].strip())
            except (ValueError, IndexError):
                pass

    # Flush the final pair.
    save_current()
    return results
+
+
def _empty_window_result(w_start: float, w_end: float, error: str) -> dict[str, Any]:
    """Build the empty window-analysis dict used when parsing fails, carrying ``error``."""
    return {
        "window": {"start_s": w_start, "end_s": w_end},
        "caption": "",
        "entities_present": [],
        "new_entities": [],
        "relations": [],
        "on_screen_text": [],
        "_error": error,
    }


def parse_window_response(
    response_text: str, w_start: float, w_end: float, frame_count: int
) -> dict[str, Any]:
    """
    Parse VLM response for a window analysis.

    Args:
        response_text: Raw text response from VLM
        w_start: Window start time
        w_end: Window end time
        frame_count: Number of frames in window (currently unused; kept for
            interface stability)

    Returns:
        Parsed dictionary with defaults filled in. If parsing fails, returns
        a dict with "_error" key instead of raising.
    """
    # Try to extract JSON (handles code fences)
    parsed = extract_json(response_text)
    if parsed is None:
        return _empty_window_result(
            w_start, w_end, f"Failed to parse JSON from response: {response_text[:200]}..."
        )

    # extract_json can return a list; wrap that case in the default structure.
    # This shouldn't happen with proper structured output, but handle gracefully.
    if isinstance(parsed, list):
        return _empty_window_result(w_start, w_end, f"Unexpected list response: {parsed}")

    # Any other non-dict value is also an error.
    if not isinstance(parsed, dict):
        return _empty_window_result(
            w_start, w_end, f"Expected dict or list, got {type(parsed)}: {parsed}"
        )

    return parsed
diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/prompts.py b/dimos/perception/experimental/temporal_memory/temporal_utils/prompts.py
new file mode 100644
index 0000000000..61399fd3f1
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_utils/prompts.py
@@ -0,0 +1,353 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Prompt building functions for VLM queries."""
+
+import json
+from typing import Any
+
+from .helpers import clamp_text, next_entity_id_hint
+
# JSON schema for window responses (from VideoRAG)
# Doubles as the json_schema payload for OpenAI structured output; only
# "window" and "caption" are required at the top level — everything else
# is optional so partial analyses still validate.
WINDOW_RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        # Time bounds (seconds) of the analyzed window.
        "window": {
            "type": "object",
            "properties": {"start_s": {"type": "number"}, "end_s": {"type": "number"}},
            "required": ["start_s", "end_s"],
        },
        # Dense grounded description of what is visible in the window.
        "caption": {"type": "string"},
        # Roster entities detected in this window (IDs like "E1").
        "entities_present": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "confidence": {"type": "number", "minimum": 0, "maximum": 1},
                },
                "required": ["id"],
            },
        },
        # Entities introduced for the first time in this window.
        "new_entities": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "type": {
                        "type": "string",
                        "enum": ["person", "object", "screen", "text", "location", "other"],
                    },
                    "descriptor": {"type": "string"},
                },
                "required": ["id", "type"],
            },
        },
        # Grounded relations/events between entities (subject -> object).
        "relations": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "type": {"type": "string"},
                    "subject": {"type": "string"},
                    "object": {"type": "string"},
                    "confidence": {"type": "number", "minimum": 0, "maximum": 1},
                    "evidence": {"type": "array", "items": {"type": "string"}},
                    "notes": {"type": "string"},
                },
                "required": ["type", "subject", "object"],
            },
        },
        # Verbatim text snippets read from the frames.
        "on_screen_text": {"type": "array", "items": {"type": "string"}},
        "uncertainties": {"type": "array", "items": {"type": "string"}},
        "confidence": {"type": "number", "minimum": 0, "maximum": 1},
    },
    "required": ["window", "caption"],
}
+
+
def build_window_prompt(
    *,
    w_start: float,
    w_end: float,
    frame_count: int,
    state: dict[str, Any],
) -> str:
    """
    Build comprehensive VLM prompt for analyzing a video window.

    This is adapted from videorag's build_window_messages() but formatted
    as a single text prompt for VlModel.query() instead of OpenAI's messages format.

    Args:
        w_start: Window start time in seconds
        w_end: Window end time in seconds
        frame_count: Number of frames in this window
        state: Current temporal memory state (entity_roster, rolling_summary, etc.)

    Returns:
        Formatted prompt string
    """
    roster = state.get("entity_roster", [])
    rolling_summary = state.get("rolling_summary", "")
    # Tell the VLM where to start numbering so new IDs never collide with the roster.
    next_id = next_entity_id_hint(roster)

    # System instructions (from VideoRAG)
    system_context = """You analyze short sequences of video frames.
You must stay grounded in what is visible.
Do not identify real people or guess names/identities; describe people anonymously.
Extract general entities (people, objects, screens, text, locations) and relations between them.
Use stable entity IDs like E1, E2 based on the provided roster."""

    # Main prompt (from VideoRAG's build_window_messages).
    # Doubled braces ({{ }}) render as literal braces in the f-string so the
    # JSON example below survives formatting; rolling summary is clamped to
    # 1500 chars to bound prompt size.
    prompt = f"""{system_context}

Time window: [{w_start:.3f}, {w_end:.3f}) seconds
Number of frames: {frame_count}

Existing entity roster (may be empty):
{json.dumps(roster, ensure_ascii=False)}

Rolling summary so far (may be empty):
{clamp_text(str(rolling_summary), 1500)}

Task:
1) Write a dense, grounded caption describing what is visible across the frames in this time window.
2) Identify which existing roster entities appear in these frames.
3) Add any new salient entities (people/objects/screens/text/locations) with a short grounded descriptor.
4) Extract grounded relations/events between entities (e.g., looks_at, holds, uses, walks_past, speaks_to (inferred)).

New entity IDs must start at: {next_id}

Rules (important):
- You MUST stay grounded in what is visible in the provided frames.
- You MUST NOT mention any entity ID unless it appears in the provided roster OR you include it in new_entities in this same output.
- If the roster is empty, introduce any salient entities you reference (start with E1, E2, ...).
- Do not invent on-screen text: only include text you can read.
- If a relation is inferred (e.g., speaks_to without audio), include it but lower confidence and explain the visual cues.

Output JSON ONLY with this schema:
{{
  "window": {{"start_s": {w_start:.3f}, "end_s": {w_end:.3f}}},
  "caption": "dense grounded description",
  "entities_present": [{{"id": "E1", "confidence": 0.0-1.0}}],
  "new_entities": [{{"id": "E3", "type": "person|object|screen|text|location|other", "descriptor": "..."}}],
  "relations": [
    {{
      "type": "speaks_to|looks_at|holds|uses|moves|gesture|scene_change|other",
      "subject": "E1|unknown",
      "object": "E2|unknown",
      "confidence": 0.0-1.0,
      "evidence": ["describe which frames show this"],
      "notes": "short, grounded"
    }}
  ],
  "on_screen_text": ["verbatim snippets"],
  "uncertainties": ["things that are unclear"],
  "confidence": 0.0-1.0
}}
"""
    return prompt
+
+
def build_summary_prompt(
    *,
    rolling_summary: str,
    chunk_windows: list[dict[str, Any]],
) -> str:
    """
    Build prompt for updating rolling summary.

    This is adapted from videorag's build_summary_messages() but formatted
    as a single text prompt for VlModel.query().

    Args:
        rolling_summary: Current rolling summary text
        chunk_windows: List of recent window results to incorporate

    Returns:
        Formatted prompt string (instructs the model to reply in plain text,
        capped at roughly 120 words / 900 characters)
    """
    # System context (from VideoRAG)
    system_context = """You summarize timestamped video-window logs into a concise rolling summary.
Stay grounded in the provided window captions/relations.
Do not invent entities or rename entity IDs; preserve IDs like E1, E2 exactly.
You MAY incorporate new entity IDs if they appear in the provided chunk windows (e.g., in new_entities).
Be concise, but keep relevant entity continuity and key relations."""

    # Previous summary is clamped to 2500 chars to keep the prompt bounded;
    # the new chunk windows are embedded as raw JSON.
    prompt = f"""{system_context}

Update the rolling summary using the newest chunk.

Previous rolling summary (may be empty):
{clamp_text(rolling_summary, 2500)}

New chunk windows (JSON):
{json.dumps(chunk_windows, ensure_ascii=False)}

Output a concise summary as PLAIN TEXT (no JSON, no code fences).
Length constraints (important):
- Target <= 120 words total.
- Hard cap <= 900 characters.
"""
    return prompt
+
+
def build_query_prompt(
    *,
    question: str,
    context: dict[str, Any],
) -> str:
    """
    Build prompt for querying temporal memory.

    Args:
        question: User's question about the video stream
        context: Context dict containing entity_roster, rolling_summary, etc.

    Returns:
        Formatted prompt string
    """
    # Spell out the presence list explicitly so the model does not hallucinate
    # entities when the recent-windows list is empty.
    currently_present = context.get("currently_present_entities", [])
    currently_present_str = (
        f"Entities recently detected in recent windows: {currently_present}"
        if currently_present
        else "No entities were detected in recent windows (list is empty)"
    )

    # The whole context dict is embedded as pretty-printed JSON.
    prompt = f"""Answer the following question about the video stream using the provided context.

**Question:** {question}

**Context:**
{json.dumps(context, indent=2, ensure_ascii=False)}

**Important Notes:**
- Entities have stable IDs like E1, E2, etc.
- The 'currently_present_entities' list contains entity IDs that were detected in recent video windows (not necessarily in the current frame you're viewing)
- {currently_present_str}
- The 'entity_roster' contains all known entities with their descriptions
- The 'rolling_summary' describes what has happened over time
- If 'currently_present_entities' is empty, it means no entities were detected in recent windows, but entities may still exist in the roster from earlier
- Answer based on the provided context (entity_roster, rolling_summary, currently_present_entities) AND what you see in the current frame
- If the context says entities were present but you don't see them in the current frame, mention both: what was recently detected AND what you currently see

Provide a concise answer.
"""
    return prompt
+
+
def build_distance_estimation_prompt(
    *,
    entity_a_descriptor: str,
    entity_a_id: str,
    entity_b_descriptor: str,
    entity_b_id: str,
) -> str:
    """Build a VLM prompt asking for the distance between two entities in an image.

    Args:
        entity_a_descriptor: Description of first entity
        entity_a_id: ID of first entity
        entity_b_descriptor: Description of second entity
        entity_b_id: ID of second entity

    Returns:
        Formatted prompt string; the expected reply format is key:value lines
        (category / distance_m / confidence / reasoning).
    """
    return f"""Look at this image and estimate the distance between these two entities:

Entity A: {entity_a_descriptor} (ID: {entity_a_id})
Entity B: {entity_b_descriptor} (ID: {entity_b_id})

Provide:
1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m)
2. Approximate distance in meters (best guess)
3. Confidence: 0.0-1.0 (how certain are you?)

Respond in this format:
category: [near/medium/far]
distance_m: [number]
confidence: [0.0-1.0]
reasoning: [brief explanation]"""
+
+
def build_batch_distance_estimation_prompt(
    entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]],
) -> str:
    """Build one VLM prompt covering distance estimates for several entity pairs.

    Args:
        entity_pairs: List of (entity_a, entity_b) tuples; each entity is a
            dict with 'id' and 'descriptor'

    Returns:
        Formatted prompt string; the reply is expected as one "Pair N" block
        per pair with category / distance_m / confidence lines.
    """
    # One numbered block per pair so the parser can match replies by "Pair N".
    blocks = [
        f"Pair {number}:\n"
        f" Entity A: {first['descriptor']} (ID: {first['id']})\n"
        f" Entity B: {second['descriptor']} (ID: {second['id']})"
        for number, (first, second) in enumerate(entity_pairs, 1)
    ]
    joined_blocks = "\n".join(blocks)

    return f"""Look at this image and estimate the distances between the following entity pairs:

{joined_blocks}

For each pair, provide:
1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m)
2. Approximate distance in meters (best guess)
3. Confidence: 0.0-1.0 (how certain are you?)

Respond in this format (one block per pair):
Pair 1:
category: [near/medium/far]
distance_m: [number]
confidence: [0.0-1.0]

Pair 2:
category: [near/medium/far]
distance_m: [number]
confidence: [0.0-1.0]

(etc.)"""
+
+
def get_structured_output_format() -> dict[str, Any]:
    """
    Return the OpenAI-compatible ``response_format`` payload for window responses.

    Uses json_schema mode (available in the OpenAI API, e.g. GPT-4o mini) to
    enforce the VideoRAG response schema.

    Returns:
        Dictionary for the response_format parameter:
        {"type": "json_schema", "json_schema": {...}}
    """
    schema_wrapper = {
        "name": "video_window_analysis",
        "description": "Analysis of a video window with entities and relations",
        "schema": WINDOW_RESPONSE_SCHEMA,
        # Non-strict so the model may emit fields beyond the schema.
        "strict": False,
    }
    return {"type": "json_schema", "json_schema": schema_wrapper}
diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/state.py b/dimos/perception/experimental/temporal_memory/temporal_utils/state.py
new file mode 100644
index 0000000000..9cdfbe4931
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/temporal_utils/state.py
@@ -0,0 +1,139 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""State management functions for temporal memory."""
+
+from typing import Any
+
+
def default_state() -> dict[str, Any]:
    """Return a fresh default temporal-memory state dictionary.

    A new dict is built on every call so callers never share mutable state.
    """
    return dict(
        entity_roster=[],
        rolling_summary="",
        chunk_buffer=[],
        next_summary_at_s=0.0,
        last_present=[],
    )
+
+
def update_state_from_window(
    state: dict[str, Any],
    parsed: dict[str, Any],
    w_end: float,
    summary_interval_s: float,
) -> bool:
    """
    Update temporal memory state from a parsed window result.

    This implements the state update logic from VideoRAG's generate_evidence().

    Args:
        state: Current state dictionary (modified in place)
        parsed: Parsed window result
        w_end: Window end time
        summary_interval_s: How often to trigger summary updates (<= 0 disables
            summary triggering)

    Returns:
        True if summary update is needed, False otherwise
    """
    # Skip if there was an error
    if "_error" in parsed:
        return False

    new_entities = parsed.get("new_entities", [])
    present = parsed.get("entities_present", [])

    # Handle new entities: append any IDs not already in the roster.
    if new_entities:
        roster = list(state.get("entity_roster", []))
        known = {e.get("id") for e in roster if isinstance(e, dict)}
        for e in new_entities:
            if isinstance(e, dict) and e.get("id") not in known:
                roster.append(e)
                known.add(e.get("id"))
        state["entity_roster"] = roster

    # Handle referenced entities (auto-add if mentioned but not in roster).
    # Both entities_present and relation endpoints count as references.
    roster = list(state.get("entity_roster", []))
    known = {e.get("id") for e in roster if isinstance(e, dict)}
    referenced: set[str] = set()
    for p in present or []:
        if isinstance(p, dict) and isinstance(p.get("id"), str):
            referenced.add(p["id"])
    for rel in parsed.get("relations") or []:
        if isinstance(rel, dict):
            for k in ("subject", "object"):
                v = rel.get(k)
                # "unknown" is the sentinel for an unidentified endpoint.
                if isinstance(v, str) and v != "unknown":
                    referenced.add(v)
    # Sorted iteration keeps auto-added roster order deterministic.
    for rid in sorted(referenced):
        if rid not in known:
            roster.append(
                {
                    "id": rid,
                    "type": "other",
                    "descriptor": "unknown (auto-added; rerun recommended)",
                }
            )
            known.add(rid)
    state["entity_roster"] = roster
    state["last_present"] = present

    # Add this window's results to the chunk buffer (repaired if corrupted).
    chunk_buffer = state.get("chunk_buffer", [])
    if not isinstance(chunk_buffer, list):
        chunk_buffer = []
    chunk_buffer.append(
        {
            "window": parsed.get("window"),
            "caption": parsed.get("caption", ""),
            "entities_present": parsed.get("entities_present", []),
            "new_entities": parsed.get("new_entities", []),
            "relations": parsed.get("relations", []),
            "on_screen_text": parsed.get("on_screen_text", []),
        }
    )
    state["chunk_buffer"] = chunk_buffer

    # Check if summary update is needed: the window end has reached the next
    # scheduled summary time and there is buffered material to summarize.
    if summary_interval_s > 0:
        next_at = float(state.get("next_summary_at_s", summary_interval_s))
        # Small epsilon guards against float round-off at the boundary.
        if w_end + 1e-6 >= next_at and chunk_buffer:
            return True  # Need to update summary

    return False
+
+
def apply_summary_update(
    state: dict[str, Any], summary_text: str, w_end: float, summary_interval_s: float
) -> None:
    """
    Apply a summary update to the state.

    Args:
        state: State dictionary (modified in place)
        summary_text: New summary text (blank/whitespace-only text is ignored)
        w_end: Current window end time
        summary_interval_s: Summary update interval. Must be > 0 for the
            next-summary deadline to advance; a non-positive interval
            previously caused an infinite loop here and now leaves the
            deadline unchanged.
    """
    # Commit the new summary and drop the buffered chunks it now covers.
    if summary_text and summary_text.strip():
        state["rolling_summary"] = summary_text.strip()
        state["chunk_buffer"] = []

    # Advance next_summary_at_s to the first deadline strictly after w_end
    # (epsilon guards against float round-off at the boundary).
    next_at = float(state.get("next_summary_at_s", summary_interval_s))
    if summary_interval_s > 0:
        while next_at <= w_end + 1e-6:
            next_at += float(summary_interval_s)
    state["next_summary_at_s"] = next_at
diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py
new file mode 100644
index 0000000000..7b38e4ce40
--- /dev/null
+++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py
@@ -0,0 +1,230 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import pathlib
+import tempfile
+import time
+
+from dotenv import load_dotenv
+import pytest
+from reactivex import operators as ops
+
+from dimos import core
+from dimos.core import Module, Out, rpc
+from dimos.models.vl.openai import OpenAIVlModel
+from dimos.msgs.sensor_msgs import Image
+from dimos.perception.experimental.temporal_memory import TemporalMemory, TemporalMemoryConfig
+from dimos.protocol import pubsub
+from dimos.utils.data import get_data
+from dimos.utils.logging_config import setup_logger
+from dimos.utils.testing import TimedSensorReplay
+
+# Load environment variables
+load_dotenv()
+
+logger = setup_logger()
+
+pubsub.lcm.autoconf()
+
+
+class VideoReplayModule(Module):
+ """Module that replays video data from TimedSensorReplay."""
+
+ video_out: Out[Image]
+
+ def __init__(self, video_path: str) -> None:
+ super().__init__()
+ self.video_path = video_path
+
+ @rpc
+ def start(self) -> None:
+ """Start replaying video data."""
+ # Use TimedSensorReplay to replay video frames
+ video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy)
+
+ # Subscribe to the replay stream and publish to LCM
+ self._disposables.add(
+ video_replay.stream()
+ .pipe(
+ ops.sample(1), # Sample every 1 second
+ ops.take(10), # Only take 10 frames total
+ )
+ .subscribe(self.video_out.publish)
+ )
+
+ logger.info("VideoReplayModule started")
+
+ @rpc
+ def stop(self) -> None:
+ """Stop replaying video data."""
+ # Stop all stream transports to clean up LCM loop threads
+ for stream in list(self.outputs.values()):
+ if stream.transport is not None and hasattr(stream.transport, "stop"):
+ stream.transport.stop()
+ stream._transport = None
+ super().stop()
+ logger.info("VideoReplayModule stopped")
+
+
+@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM replay + dataset not CI-safe.")
+@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.")
+@pytest.mark.neverending
+class TestTemporalMemoryModule:
+ @pytest.fixture(scope="function")
+ def temp_dir(self):
+ """Create a temporary directory for test data."""
+ temp_dir = tempfile.mkdtemp(prefix="temporal_memory_test_")
+ yield temp_dir
+
+ @pytest.fixture(scope="function")
+ def dimos_cluster(self):
+ """Create and cleanup Dimos cluster."""
+ dimos = core.start(1)
+ yield dimos
+ dimos.close_all()
+
+ @pytest.fixture(scope="function")
+ def video_module(self, dimos_cluster):
+ """Create and cleanup video replay module."""
+ data_path = get_data("unitree_office_walk")
+ video_path = os.path.join(data_path, "video")
+ video_module = dimos_cluster.deploy(VideoReplayModule, video_path)
+ video_module.video_out.transport = core.LCMTransport("/test_video", Image)
+ yield video_module
+ try:
+ video_module.stop()
+ except Exception as e:
+ logger.warning(f"Failed to stop video_module: {e}")
+
+ @pytest.fixture(scope="function")
+ def temporal_memory(self, dimos_cluster, temp_dir):
+ """Create and cleanup temporal memory module."""
+ output_dir = os.path.join(temp_dir, "temporal_memory_output")
+ # Create OpenAIVlModel instance
+ api_key = os.getenv("OPENAI_API_KEY")
+ vlm = OpenAIVlModel(api_key=api_key)
+
+ temporal_memory = dimos_cluster.deploy(
+ TemporalMemory,
+ vlm=vlm,
+ config=TemporalMemoryConfig(
+ fps=1.0, # Process 1 frame per second
+ window_s=2.0, # Analyze 2-second windows
+ stride_s=2.0, # New window every 2 seconds
+ summary_interval_s=10.0, # Update rolling summary every 10 seconds
+ max_frames_per_window=3, # Max 3 frames per window
+ output_dir=output_dir,
+ ),
+ )
+ yield temporal_memory
+ try:
+ temporal_memory.stop()
+ except Exception as e:
+ logger.warning(f"Failed to stop temporal_memory: {e}")
+
+ @pytest.mark.asyncio
+ async def test_temporal_memory_module_with_replay(
+ self, dimos_cluster, video_module, temporal_memory, temp_dir
+ ):
+ """Test TemporalMemory module with TimedSensorReplay inputs."""
+ # Connect streams
+ temporal_memory.color_image.connect(video_module.video_out)
+
+ # Start all modules
+ video_module.start()
+ temporal_memory.start()
+ logger.info("All modules started, processing in background...")
+
+ # Wait for frames to be processed with timeout
+ timeout = 15.0 # 15 second timeout
+ start_time = time.time()
+
+ # Keep checking state while modules are running
+ while (time.time() - start_time) < timeout:
+ state = temporal_memory.get_state()
+ if state["frame_count"] > 0:
+ logger.info(
+ f"Frames processing - Frame count: {state['frame_count']}, "
+ f"Buffer size: {state['buffer_size']}, "
+ f"Entity count: {state['entity_count']}"
+ )
+ if state["frame_count"] >= 3: # Wait for at least 3 frames
+ break
+ await asyncio.sleep(0.5)
+ else:
+ # Timeout reached
+ state = temporal_memory.get_state()
+ logger.error(
+ f"Timeout after {timeout}s - Frame count: {state['frame_count']}, "
+ f"Buffer size: {state['buffer_size']}"
+ )
+ raise AssertionError(f"No frames processed within {timeout} seconds")
+
+ await asyncio.sleep(3) # Wait for more processing
+
+ # Test get_state() RPC method
+ mid_state = temporal_memory.get_state()
+ logger.info(
+ f"Mid-test state - Frame count: {mid_state['frame_count']}, "
+ f"Entity count: {mid_state['entity_count']}, "
+ f"Recent windows: {mid_state['recent_windows']}"
+ )
+ assert mid_state["frame_count"] >= state["frame_count"], (
+ "Frame count should increase or stay same"
+ )
+
+ # Test query() RPC method
+ answer = temporal_memory.query("What entities are currently visible?")
+ logger.info(f"Query result: {answer[:200]}...")
+ assert len(answer) > 0, "Query should return a non-empty answer"
+
+ # Test get_entity_roster() RPC method
+ entities = temporal_memory.get_entity_roster()
+ logger.info(f"Entity roster has {len(entities)} entities")
+ assert isinstance(entities, list), "Entity roster should be a list"
+
+ # Test get_rolling_summary() RPC method
+ summary = temporal_memory.get_rolling_summary()
+ logger.info(f"Rolling summary: {summary[:200] if summary else 'empty'}...")
+ assert isinstance(summary, str), "Rolling summary should be a string"
+
+ final_state = temporal_memory.get_state()
+ logger.info(
+ f"Final state - Frame count: {final_state['frame_count']}, "
+ f"Entity count: {final_state['entity_count']}, "
+ f"Recent windows: {final_state['recent_windows']}"
+ )
+
+ video_module.stop()
+ temporal_memory.stop()
+ logger.info("Stopped modules")
+
+ # Wait a bit for file operations to complete
+ await asyncio.sleep(0.5)
+
+ # Verify files were created - stop() already saved them
+ output_dir = os.path.join(temp_dir, "temporal_memory_output")
+ output_path = pathlib.Path(output_dir)
+ assert output_path.exists(), f"Output directory should exist: {output_dir}"
+ assert (output_path / "state.json").exists(), "state.json should exist"
+ assert (output_path / "entities.json").exists(), "entities.json should exist"
+ assert (output_path / "frames_index.jsonl").exists(), "frames_index.jsonl should exist"
+
+ logger.info("All temporal memory module tests passed!")
+
+
+if __name__ == "__main__":
+ pytest.main(["-v", "-s", __file__])
diff --git a/dimos/perception/grasp_generation/__init__.py b/dimos/perception/grasp_generation/__init__.py
deleted file mode 100644
index 16281fe0b6..0000000000
--- a/dimos/perception/grasp_generation/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .utils import *
diff --git a/dimos/perception/grasp_generation/grasp_generation.py b/dimos/perception/grasp_generation/grasp_generation.py
deleted file mode 100644
index 4f2e4b68a1..0000000000
--- a/dimos/perception/grasp_generation/grasp_generation.py
+++ /dev/null
@@ -1,233 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Dimensional-hosted grasp generation for manipulation pipeline.
-"""
-
-import asyncio
-
-import numpy as np
-import open3d as o3d # type: ignore[import-untyped]
-
-from dimos.perception.grasp_generation.utils import parse_grasp_results
-from dimos.types.manipulation import ObjectData
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class HostedGraspGenerator:
- """
- Dimensional-hosted grasp generator using WebSocket communication.
- """
-
- def __init__(self, server_url: str) -> None:
- """
- Initialize Dimensional-hosted grasp generator.
-
- Args:
- server_url: WebSocket URL for Dimensional-hosted grasp generator server
- """
- self.server_url = server_url
- logger.info(f"Initialized grasp generator with server: {server_url}")
-
- def generate_grasps_from_objects(
- self, objects: list[ObjectData], full_pcd: o3d.geometry.PointCloud
- ) -> list[dict]: # type: ignore[type-arg]
- """
- Generate grasps from ObjectData objects using grasp generator.
-
- Args:
- objects: List of ObjectData with point clouds
- full_pcd: Open3D point cloud of full scene
-
- Returns:
- Parsed grasp results as list of dictionaries
- """
- try:
- # Combine all point clouds
- all_points = []
- all_colors = []
- valid_objects = 0
-
- for obj in objects:
- if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None:
- continue
-
- points = obj["point_cloud_numpy"]
- if not isinstance(points, np.ndarray) or points.size == 0:
- continue
-
- if len(points.shape) != 2 or points.shape[1] != 3:
- continue
-
- colors = None
- if "colors_numpy" in obj and obj["colors_numpy"] is not None: # type: ignore[typeddict-item]
- colors = obj["colors_numpy"] # type: ignore[typeddict-item]
- if isinstance(colors, np.ndarray) and colors.size > 0:
- if (
- colors.shape[0] != points.shape[0]
- or len(colors.shape) != 2
- or colors.shape[1] != 3
- ):
- colors = None
-
- all_points.append(points)
- if colors is not None:
- all_colors.append(colors)
- valid_objects += 1
-
- if not all_points:
- return []
-
- # Combine point clouds
- combined_points = np.vstack(all_points)
- combined_colors = None
- if len(all_colors) == valid_objects and len(all_colors) > 0:
- combined_colors = np.vstack(all_colors)
-
- # Send grasp request
- grasps = self._send_grasp_request_sync(combined_points, combined_colors)
-
- if not grasps:
- return []
-
- # Parse and return results in list of dictionaries format
- return parse_grasp_results(grasps)
-
- except Exception as e:
- logger.error(f"Grasp generation failed: {e}")
- return []
-
- def _send_grasp_request_sync(
- self,
- points: np.ndarray, # type: ignore[type-arg]
- colors: np.ndarray | None, # type: ignore[type-arg]
- ) -> list[dict] | None: # type: ignore[type-arg]
- """Send synchronous grasp request to grasp server."""
-
- try:
- # Prepare colors
- colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5
-
- # Ensure correct data types
- points = points.astype(np.float32)
- colors = colors.astype(np.float32)
-
- # Validate ranges
- if np.any(np.isnan(points)) or np.any(np.isinf(points)):
- logger.error("Points contain NaN or Inf values")
- return None
- if np.any(np.isnan(colors)) or np.any(np.isinf(colors)):
- logger.error("Colors contain NaN or Inf values")
- return None
-
- colors = np.clip(colors, 0.0, 1.0)
-
- # Run async request in sync context
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- result = loop.run_until_complete(self._async_grasp_request(points, colors))
- return result
- finally:
- loop.close()
-
- except Exception as e:
- logger.error(f"Error in synchronous grasp request: {e}")
- return None
-
- async def _async_grasp_request(
- self,
- points: np.ndarray, # type: ignore[type-arg]
- colors: np.ndarray, # type: ignore[type-arg]
- ) -> list[dict] | None: # type: ignore[type-arg]
- """Async grasp request helper."""
- import json
-
- import websockets
-
- try:
- async with websockets.connect(self.server_url) as websocket:
- request = {
- "points": points.tolist(),
- "colors": colors.tolist(),
- "lims": [-1.0, 1.0, -1.0, 1.0, 0.0, 2.0],
- }
-
- await websocket.send(json.dumps(request))
- response = await websocket.recv()
- grasps = json.loads(response)
-
- if isinstance(grasps, dict) and "error" in grasps:
- logger.error(f"Server returned error: {grasps['error']}")
- return None
- elif isinstance(grasps, int | float) and grasps == 0:
- return None
- elif not isinstance(grasps, list):
- logger.error(f"Server returned unexpected response type: {type(grasps)}")
- return None
- elif len(grasps) == 0:
- return None
-
- return self._convert_grasp_format(grasps)
-
- except Exception as e:
- logger.error(f"Async grasp request failed: {e}")
- return None
-
- def _convert_grasp_format(self, grasps: list[dict]) -> list[dict]: # type: ignore[type-arg]
- """Convert Dimensional Grasp format to visualization format."""
- converted = []
-
- for i, grasp in enumerate(grasps):
- rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3)))
- euler_angles = self._rotation_matrix_to_euler(rotation_matrix)
-
- converted_grasp = {
- "id": f"grasp_{i}",
- "score": grasp.get("score", 0.0),
- "width": grasp.get("width", 0.0),
- "height": grasp.get("height", 0.0),
- "depth": grasp.get("depth", 0.0),
- "translation": grasp.get("translation", [0, 0, 0]),
- "rotation_matrix": rotation_matrix.tolist(),
- "euler_angles": euler_angles,
- }
- converted.append(converted_grasp)
-
- converted.sort(key=lambda x: x["score"], reverse=True)
- return converted
-
- def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> dict[str, float]: # type: ignore[type-arg]
- """Convert rotation matrix to Euler angles (in radians)."""
- sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2)
-
- singular = sy < 1e-6
-
- if not singular:
- x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
- y = np.arctan2(-rotation_matrix[2, 0], sy)
- z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
- else:
- x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1])
- y = np.arctan2(-rotation_matrix[2, 0], sy)
- z = 0
-
- return {"roll": x, "pitch": y, "yaw": z}
-
- def cleanup(self) -> None:
- """Clean up resources."""
- logger.info("Grasp generator cleaned up")
diff --git a/dimos/perception/grasp_generation/utils.py b/dimos/perception/grasp_generation/utils.py
deleted file mode 100644
index 492a3d1df4..0000000000
--- a/dimos/perception/grasp_generation/utils.py
+++ /dev/null
@@ -1,529 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Utilities for grasp generation and visualization."""
-
-import cv2
-import numpy as np
-import open3d as o3d # type: ignore[import-untyped]
-
-from dimos.perception.common.utils import project_3d_points_to_2d
-
-
-def create_gripper_geometry(
- grasp_data: dict, # type: ignore[type-arg]
- finger_length: float = 0.08,
- finger_thickness: float = 0.004,
-) -> list[o3d.geometry.TriangleMesh]:
- """
- Create a simple fork-like gripper geometry from grasp data.
-
- Args:
- grasp_data: Dictionary containing grasp parameters
- - translation: 3D position list
- - rotation_matrix: 3x3 rotation matrix defining gripper coordinate system
- * X-axis: gripper width direction (opening/closing)
- * Y-axis: finger length direction
- * Z-axis: approach direction (toward object)
- - width: Gripper opening width
- finger_length: Length of gripper fingers (longer)
- finger_thickness: Thickness of gripper fingers
- base_height: Height of gripper base (longer)
- color: RGB color for the gripper (solid blue)
-
- Returns:
- List of Open3D TriangleMesh geometries for the gripper
- """
-
- translation = np.array(grasp_data["translation"])
- rotation_matrix = np.array(grasp_data["rotation_matrix"])
-
- width = grasp_data.get("width", 0.04)
-
- # Create transformation matrix
- transform = np.eye(4)
- transform[:3, :3] = rotation_matrix
- transform[:3, 3] = translation
-
- geometries = []
-
- # Gripper dimensions
- finger_width = 0.006 # Thickness of each finger
- handle_length = 0.05 # Length of handle extending backward
-
- # Build gripper in local coordinate system:
- # X-axis = width direction (left/right finger separation)
- # Y-axis = finger length direction (fingers extend along +Y)
- # Z-axis = approach direction (toward object, handle extends along -Z)
- # IMPORTANT: Fingertips should be at origin (translation point)
-
- # Create left finger extending along +Y, positioned at +X
- left_finger = o3d.geometry.TriangleMesh.create_box(
- width=finger_width, # Thin finger
- height=finger_length, # Extends along Y (finger length direction)
- depth=finger_thickness, # Thin in Z direction
- )
- left_finger.translate(
- [
- width / 2 - finger_width / 2, # Position at +X (half width from center)
- -finger_length, # Shift so fingertips are at origin
- -finger_thickness / 2, # Center in Z
- ]
- )
-
- # Create right finger extending along +Y, positioned at -X
- right_finger = o3d.geometry.TriangleMesh.create_box(
- width=finger_width, # Thin finger
- height=finger_length, # Extends along Y (finger length direction)
- depth=finger_thickness, # Thin in Z direction
- )
- right_finger.translate(
- [
- -width / 2 - finger_width / 2, # Position at -X (half width from center)
- -finger_length, # Shift so fingertips are at origin
- -finger_thickness / 2, # Center in Z
- ]
- )
-
- # Create base connecting fingers - flat like a stickman body
- base = o3d.geometry.TriangleMesh.create_box(
- width=width + finger_width, # Full width plus finger thickness
- height=finger_thickness, # Flat like fingers (stickman style)
- depth=finger_thickness, # Thin like fingers
- )
- base.translate(
- [
- -width / 2 - finger_width / 2, # Start from left finger position
- -finger_length - finger_thickness, # Behind fingers, adjusted for fingertips at origin
- -finger_thickness / 2, # Center in Z
- ]
- )
-
- # Create handle extending backward - flat stick like stickman arm
- handle = o3d.geometry.TriangleMesh.create_box(
- width=finger_width, # Same width as fingers
- height=handle_length, # Extends backward along Y direction (same plane)
- depth=finger_thickness, # Thin like fingers (same plane)
- )
- handle.translate(
- [
- -finger_width / 2, # Center in X
- -finger_length
- - finger_thickness
- - handle_length, # Extend backward from base, adjusted for fingertips at origin
- -finger_thickness / 2, # Same Z plane as other components
- ]
- )
-
- # Use solid red color for all parts (user changed to red)
- solid_color = [1.0, 0.0, 0.0] # Red color
-
- left_finger.paint_uniform_color(solid_color)
- right_finger.paint_uniform_color(solid_color)
- base.paint_uniform_color(solid_color)
- handle.paint_uniform_color(solid_color)
-
- # Apply transformation to all parts
- left_finger.transform(transform)
- right_finger.transform(transform)
- base.transform(transform)
- handle.transform(transform)
-
- geometries.extend([left_finger, right_finger, base, handle])
-
- return geometries
-
-
-def create_all_gripper_geometries(
- grasp_list: list[dict], # type: ignore[type-arg]
- max_grasps: int = -1,
-) -> list[o3d.geometry.TriangleMesh]:
- """
- Create gripper geometries for multiple grasps.
-
- Args:
- grasp_list: List of grasp dictionaries
- max_grasps: Maximum number of grasps to visualize (-1 for all)
-
- Returns:
- List of all gripper geometries
- """
- all_geometries = []
-
- grasps_to_show = grasp_list if max_grasps < 0 else grasp_list[:max_grasps]
-
- for grasp in grasps_to_show:
- gripper_parts = create_gripper_geometry(grasp)
- all_geometries.extend(gripper_parts)
-
- return all_geometries
-
-
-def draw_grasps_on_image(
- image: np.ndarray, # type: ignore[type-arg]
- grasp_data: dict | dict[int | str, list[dict]] | list[dict], # type: ignore[type-arg]
- camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] # [fx, fy, cx, cy] or 3x3 matrix
- max_grasps: int = -1, # -1 means show all grasps
- finger_length: float = 0.08, # Match 3D gripper
- finger_thickness: float = 0.004, # Match 3D gripper
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Draw fork-like gripper visualizations on the image matching 3D gripper design.
-
- Args:
- image: Base image to draw on
- grasp_data: Can be:
- - A single grasp dict
- - A list of grasp dicts
- - A dictionary mapping object IDs or "scene" to list of grasps
- camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix
- max_grasps: Maximum number of grasps to visualize (-1 for all)
- finger_length: Length of gripper fingers (matches 3D design)
- finger_thickness: Thickness of gripper fingers (matches 3D design)
-
- Returns:
- Image with grasps drawn
- """
- result = image.copy()
-
- # Convert camera intrinsics to 3x3 matrix if needed
- if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4:
- fx, fy, cx, cy = camera_intrinsics
- camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
- else:
- camera_matrix = np.array(camera_intrinsics)
-
- # Convert input to standard format
- if isinstance(grasp_data, dict) and not any(
- key in grasp_data for key in ["scene", 0, 1, 2, 3, 4, 5]
- ):
- # Single grasp
- grasps_to_draw = [(grasp_data, 0)]
- elif isinstance(grasp_data, list):
- # List of grasps
- grasps_to_draw = [(grasp, i) for i, grasp in enumerate(grasp_data)]
- else:
- # Dictionary of grasps by object ID
- grasps_to_draw = []
- for _obj_id, grasps in grasp_data.items():
- for i, grasp in enumerate(grasps):
- grasps_to_draw.append((grasp, i))
-
- # Limit number of grasps if specified
- if max_grasps > 0:
- grasps_to_draw = grasps_to_draw[:max_grasps]
-
- # Define grasp colors (solid red to match 3D design)
- def get_grasp_color(index: int) -> tuple: # type: ignore[type-arg]
- # Use solid red color for all grasps to match 3D design
- return (0, 0, 255) # Red in BGR format for OpenCV
-
- # Draw each grasp
- for grasp, index in grasps_to_draw:
- try:
- color = get_grasp_color(index)
- thickness = max(1, 4 - index // 3)
-
- # Extract grasp parameters (using translation and rotation_matrix)
- if "translation" not in grasp or "rotation_matrix" not in grasp:
- continue
-
- translation = np.array(grasp["translation"])
- rotation_matrix = np.array(grasp["rotation_matrix"])
- width = grasp.get("width", 0.04)
-
- # Match 3D gripper dimensions
- finger_width = 0.006 # Thickness of each finger (matches 3D)
- handle_length = 0.05 # Length of handle extending backward (matches 3D)
-
- # Create gripper geometry in local coordinate system matching 3D design:
- # X-axis = width direction (left/right finger separation)
- # Y-axis = finger length direction (fingers extend along +Y)
- # Z-axis = approach direction (toward object, handle extends along -Z)
- # IMPORTANT: Fingertips should be at origin (translation point)
-
- # Left finger extending along +Y, positioned at +X
- left_finger_points = np.array(
- [
- [
- width / 2 - finger_width / 2, # type: ignore[operator]
- -finger_length,
- -finger_thickness / 2,
- ], # Back left
- [
- width / 2 + finger_width / 2, # type: ignore[operator]
- -finger_length,
- -finger_thickness / 2,
- ], # Back right
- [
- width / 2 + finger_width / 2, # type: ignore[operator]
- 0,
- -finger_thickness / 2,
- ], # Front right (at origin)
- [
- width / 2 - finger_width / 2, # type: ignore[operator]
- 0,
- -finger_thickness / 2,
- ], # Front left (at origin)
- ]
- )
-
- # Right finger extending along +Y, positioned at -X
- right_finger_points = np.array(
- [
- [
- -width / 2 - finger_width / 2, # type: ignore[operator]
- -finger_length,
- -finger_thickness / 2,
- ], # Back left
- [
- -width / 2 + finger_width / 2, # type: ignore[operator]
- -finger_length,
- -finger_thickness / 2,
- ], # Back right
- [
- -width / 2 + finger_width / 2, # type: ignore[operator]
- 0,
- -finger_thickness / 2,
- ], # Front right (at origin)
- [
- -width / 2 - finger_width / 2, # type: ignore[operator]
- 0,
- -finger_thickness / 2,
- ], # Front left (at origin)
- ]
- )
-
- # Base connecting fingers - flat rectangle behind fingers
- base_points = np.array(
- [
- [
- -width / 2 - finger_width / 2, # type: ignore[operator]
- -finger_length - finger_thickness,
- -finger_thickness / 2,
- ], # Back left
- [
- width / 2 + finger_width / 2, # type: ignore[operator]
- -finger_length - finger_thickness,
- -finger_thickness / 2,
- ], # Back right
- [
- width / 2 + finger_width / 2, # type: ignore[operator]
- -finger_length,
- -finger_thickness / 2,
- ], # Front right
- [
- -width / 2 - finger_width / 2, # type: ignore[operator]
- -finger_length,
- -finger_thickness / 2,
- ], # Front left
- ]
- )
-
- # Handle extending backward - thin rectangle
- handle_points = np.array(
- [
- [
- -finger_width / 2,
- -finger_length - finger_thickness - handle_length,
- -finger_thickness / 2,
- ], # Back left
- [
- finger_width / 2,
- -finger_length - finger_thickness - handle_length,
- -finger_thickness / 2,
- ], # Back right
- [
- finger_width / 2,
- -finger_length - finger_thickness,
- -finger_thickness / 2,
- ], # Front right
- [
- -finger_width / 2,
- -finger_length - finger_thickness,
- -finger_thickness / 2,
- ], # Front left
- ]
- )
-
- # Transform all points to world frame
- def transform_points(points): # type: ignore[no-untyped-def]
- # Apply rotation and translation
- world_points = (rotation_matrix @ points.T).T + translation
- return world_points
-
- left_finger_world = transform_points(left_finger_points) # type: ignore[no-untyped-call]
- right_finger_world = transform_points(right_finger_points) # type: ignore[no-untyped-call]
- base_world = transform_points(base_points) # type: ignore[no-untyped-call]
- handle_world = transform_points(handle_points) # type: ignore[no-untyped-call]
-
- # Project to 2D
- left_finger_2d = project_3d_points_to_2d(left_finger_world, camera_matrix)
- right_finger_2d = project_3d_points_to_2d(right_finger_world, camera_matrix)
- base_2d = project_3d_points_to_2d(base_world, camera_matrix)
- handle_2d = project_3d_points_to_2d(handle_world, camera_matrix)
-
- # Draw left finger
- pts = left_finger_2d.astype(np.int32)
- cv2.polylines(result, [pts], True, color, thickness)
-
- # Draw right finger
- pts = right_finger_2d.astype(np.int32)
- cv2.polylines(result, [pts], True, color, thickness)
-
- # Draw base
- pts = base_2d.astype(np.int32)
- cv2.polylines(result, [pts], True, color, thickness)
-
- # Draw handle
- pts = handle_2d.astype(np.int32)
- cv2.polylines(result, [pts], True, color, thickness)
-
- # Draw grasp center (fingertips at origin)
- center_2d = project_3d_points_to_2d(translation.reshape(1, -1), camera_matrix)[0]
- cv2.circle(result, tuple(center_2d.astype(int)), 3, color, -1)
-
- except Exception:
- # Skip this grasp if there's an error
- continue
-
- return result
-
-
-def get_standard_coordinate_transform(): # type: ignore[no-untyped-def]
- """
- Get a standard coordinate transformation matrix for consistent visualization.
-
- This transformation ensures that:
- - X (red) axis points right
- - Y (green) axis points up
- - Z (blue) axis points toward viewer
-
- Returns:
- 4x4 transformation matrix
- """
- # Standard transformation matrix to ensure consistent coordinate frame orientation
- transform = np.array(
- [
- [1, 0, 0, 0], # X points right
- [0, -1, 0, 0], # Y points up (flip from OpenCV to standard)
- [0, 0, -1, 0], # Z points toward viewer (flip depth)
- [0, 0, 0, 1],
- ]
- )
- return transform
-
-
-def visualize_grasps_3d(
- point_cloud: o3d.geometry.PointCloud,
- grasp_list: list[dict], # type: ignore[type-arg]
- max_grasps: int = -1,
-) -> None:
- """
- Visualize grasps in 3D with point cloud.
-
- Args:
- point_cloud: Open3D point cloud
- grasp_list: List of grasp dictionaries
- max_grasps: Maximum number of grasps to visualize
- """
- # Apply standard coordinate transformation
- transform = get_standard_coordinate_transform() # type: ignore[no-untyped-call]
-
- # Transform point cloud
- pc_copy = o3d.geometry.PointCloud(point_cloud)
- pc_copy.transform(transform)
- geometries = [pc_copy]
-
- # Transform gripper geometries
- gripper_geometries = create_all_gripper_geometries(grasp_list, max_grasps)
- for geom in gripper_geometries:
- geom.transform(transform)
- geometries.extend(gripper_geometries)
-
- # Add transformed coordinate frame
- origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
- origin_frame.transform(transform)
- geometries.append(origin_frame)
-
- o3d.visualization.draw_geometries(geometries, window_name="3D Grasp Visualization")
-
-
-def parse_grasp_results(grasps: list[dict]) -> list[dict]: # type: ignore[type-arg]
- """
- Parse grasp results into visualization format.
-
- Args:
- grasps: List of grasp dictionaries
-
- Returns:
- List of dictionaries containing:
- - id: Unique grasp identifier
- - score: Confidence score (float)
- - width: Gripper opening width (float)
- - translation: 3D position [x, y, z]
- - rotation_matrix: 3x3 rotation matrix as nested list
- """
- if not grasps:
- return []
-
- parsed_grasps = []
-
- for i, grasp in enumerate(grasps):
- # Extract data from each grasp
- translation = grasp.get("translation", [0, 0, 0])
- rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3)))
- score = float(grasp.get("score", 0.0))
- width = float(grasp.get("width", 0.08))
-
- parsed_grasp = {
- "id": f"grasp_{i}",
- "score": score,
- "width": width,
- "translation": translation,
- "rotation_matrix": rotation_matrix.tolist(),
- }
- parsed_grasps.append(parsed_grasp)
-
- return parsed_grasps
-
-
-def create_grasp_overlay(
- rgb_image: np.ndarray, # type: ignore[type-arg]
- grasps: list[dict], # type: ignore[type-arg]
- camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg]
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Create grasp visualization overlay on RGB image.
-
- Args:
- rgb_image: RGB input image
- grasps: List of grasp dictionaries in viz format
- camera_intrinsics: Camera parameters
-
- Returns:
- RGB image with grasp overlay
- """
- try:
- bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
-
- result_bgr = draw_grasps_on_image(
- bgr_image,
- grasps,
- camera_intrinsics,
- max_grasps=-1,
- )
- return cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
- except Exception:
- return rgb_image.copy()
diff --git a/dimos/perception/object_detection_stream.py b/dimos/perception/object_detection_stream.py
deleted file mode 100644
index 4d93e3ddd4..0000000000
--- a/dimos/perception/object_detection_stream.py
+++ /dev/null
@@ -1,322 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import numpy as np
-from reactivex import Observable, operators as ops
-
-from dimos.perception.detection2d.yolo_2d_det import ( # type: ignore[import-not-found, import-untyped]
- Yolo2DDetector,
-)
-
-try:
- from dimos.perception.detection2d.detic_2d_det import ( # type: ignore[import-not-found, import-untyped]
- Detic2DDetector,
- )
-
- DETIC_AVAILABLE = True
-except (ModuleNotFoundError, ImportError):
- DETIC_AVAILABLE = False
- Detic2DDetector = None
-from collections.abc import Callable
-from typing import TYPE_CHECKING
-
-from dimos.models.depth.metric3d import Metric3D
-from dimos.perception.common.utils import draw_object_detection_visualization
-from dimos.perception.detection2d.utils import ( # type: ignore[attr-defined]
- calculate_depth_from_bbox,
- calculate_object_size_from_bbox,
- calculate_position_rotation_from_bbox,
-)
-from dimos.types.vector import Vector
-from dimos.utils.logging_config import setup_logger
-from dimos.utils.transform_utils import transform_robot_to_map # type: ignore[attr-defined]
-
-if TYPE_CHECKING:
- from dimos.types.manipulation import ObjectData
-
-# Initialize logger for the ObjectDetectionStream
-logger = setup_logger()
-
-
-class ObjectDetectionStream:
- """
- A stream processor that:
- 1. Detects objects using a Detector (Detic or Yolo)
- 2. Estimates depth using Metric3D
- 3. Calculates 3D position and dimensions using camera intrinsics
- 4. Transforms coordinates to map frame
- 5. Draws bounding boxes and segmentation masks on the frame
-
- Provides a stream of structured object data with position and rotation information.
- """
-
- def __init__( # type: ignore[no-untyped-def]
- self,
- camera_intrinsics=None, # [fx, fy, cx, cy]
- device: str = "cuda",
- gt_depth_scale: float = 1000.0,
- min_confidence: float = 0.7,
- class_filter=None, # Optional list of class names to filter (e.g., ["person", "car"])
- get_pose: Callable | None = None, # type: ignore[type-arg] # Optional function to transform coordinates to map frame
- detector: Detic2DDetector | Yolo2DDetector | None = None,
- video_stream: Observable = None, # type: ignore[assignment, type-arg]
- disable_depth: bool = False, # Flag to disable monocular Metric3D depth estimation
- draw_masks: bool = False, # Flag to enable drawing segmentation masks
- ) -> None:
- """
- Initialize the ObjectDetectionStream.
-
- Args:
- camera_intrinsics: List [fx, fy, cx, cy] with camera parameters
- device: Device to run inference on ("cuda" or "cpu")
- gt_depth_scale: Ground truth depth scale for Metric3D
- min_confidence: Minimum confidence for detections
- class_filter: Optional list of class names to filter
- get_pose: Optional function to transform pose to map coordinates
- detector: Optional detector instance (Detic or Yolo)
- video_stream: Observable of video frames to process (if provided, returns a stream immediately)
- disable_depth: Flag to disable monocular Metric3D depth estimation
- draw_masks: Flag to enable drawing segmentation masks
- """
- self.min_confidence = min_confidence
- self.class_filter = class_filter
- self.get_pose = get_pose
- self.disable_depth = disable_depth
- self.draw_masks = draw_masks
- # Initialize object detector
- if detector is not None:
- self.detector = detector
- else:
- if DETIC_AVAILABLE:
- try:
- self.detector = Detic2DDetector(vocabulary=None, threshold=min_confidence)
- logger.info("Using Detic2DDetector")
- except Exception as e:
- logger.warning(
- f"Failed to initialize Detic2DDetector: {e}. Falling back to Yolo2DDetector."
- )
- self.detector = Yolo2DDetector()
- else:
- logger.info("Detic not available. Using Yolo2DDetector.")
- self.detector = Yolo2DDetector()
- # Set up camera intrinsics
- self.camera_intrinsics = camera_intrinsics
-
- # Initialize depth estimation model
- self.depth_model = None
- if not disable_depth:
- try:
- self.depth_model = Metric3D(gt_depth_scale=gt_depth_scale)
-
- if camera_intrinsics is not None:
- self.depth_model.update_intrinsic(camera_intrinsics) # type: ignore[no-untyped-call]
-
- # Create 3x3 camera matrix for calculations
- fx, fy, cx, cy = camera_intrinsics
- self.camera_matrix = np.array(
- [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32
- )
- else:
- raise ValueError("camera_intrinsics must be provided")
-
- logger.info("Depth estimation enabled with Metric3D")
- except Exception as e:
- logger.warning(f"Failed to initialize Metric3D depth model: {e}")
- logger.warning("Falling back to disable_depth=True mode")
- self.disable_depth = True
- self.depth_model = None
- else:
- logger.info("Depth estimation disabled")
-
- # If video_stream is provided, create and store the stream immediately
- self.stream = None
- if video_stream is not None:
- self.stream = self.create_stream(video_stream)
-
- def create_stream(self, video_stream: Observable) -> Observable: # type: ignore[type-arg]
- """
- Create an Observable stream of object data from a video stream.
-
- Args:
- video_stream: Observable that emits video frames
-
- Returns:
- Observable that emits dictionaries containing object data
- with position and rotation information
- """
-
- def process_frame(frame): # type: ignore[no-untyped-def]
- # TODO: More modular detector output interface
- bboxes, track_ids, class_ids, confidences, names, *mask_data = ( # type: ignore[misc]
- *self.detector.process_image(frame),
- [],
- )
-
- masks = (
- mask_data[0] # type: ignore[has-type]
- if mask_data and len(mask_data[0]) == len(bboxes) # type: ignore[has-type]
- else [None] * len(bboxes) # type: ignore[has-type]
- )
-
- # Create visualization
- viz_frame = frame.copy()
-
- # Process detections
- objects = []
- if not self.disable_depth:
- depth_map = self.depth_model.infer_depth(frame) # type: ignore[union-attr]
- depth_map = np.array(depth_map)
- else:
- depth_map = None
-
- for i, bbox in enumerate(bboxes): # type: ignore[has-type]
- # Skip if confidence is too low
- if i < len(confidences) and confidences[i] < self.min_confidence: # type: ignore[has-type]
- continue
-
- # Skip if class filter is active and class not in filter
- class_name = names[i] if i < len(names) else None # type: ignore[has-type]
- if self.class_filter and class_name not in self.class_filter:
- continue
-
- if not self.disable_depth and depth_map is not None:
- # Get depth for this object
- depth = calculate_depth_from_bbox(depth_map, bbox) # type: ignore[no-untyped-call]
- if depth is None:
- # Skip objects with invalid depth
- continue
- # Calculate object position and rotation
- position, rotation = calculate_position_rotation_from_bbox(
- bbox, depth, self.camera_intrinsics
- )
- # Get object dimensions
- width, height = calculate_object_size_from_bbox(
- bbox, depth, self.camera_intrinsics
- )
-
- # Transform to map frame if a transform function is provided
- try:
- if self.get_pose:
- # position and rotation are already Vector objects, no need to convert
- robot_pose = self.get_pose()
- position, rotation = transform_robot_to_map(
- robot_pose["position"], robot_pose["rotation"], position, rotation
- )
- except Exception as e:
- logger.error(f"Error transforming to map frame: {e}")
- position, rotation = position, rotation
-
- else:
- depth = -1
- position = Vector(0, 0, 0) # type: ignore[arg-type]
- rotation = Vector(0, 0, 0) # type: ignore[arg-type]
- width = -1
- height = -1
-
- # Create a properly typed ObjectData instance
- object_data: ObjectData = {
- "object_id": track_ids[i] if i < len(track_ids) else -1, # type: ignore[has-type]
- "bbox": bbox,
- "depth": depth,
- "confidence": confidences[i] if i < len(confidences) else None, # type: ignore[has-type, typeddict-item]
- "class_id": class_ids[i] if i < len(class_ids) else None, # type: ignore[has-type, typeddict-item]
- "label": class_name, # type: ignore[typeddict-item]
- "position": position,
- "rotation": rotation,
- "size": {"width": width, "height": height},
- "segmentation_mask": masks[i],
- }
-
- objects.append(object_data)
-
- # Create visualization using common function
- viz_frame = draw_object_detection_visualization(
- viz_frame, objects, draw_masks=self.draw_masks, font_scale=1.5
- )
-
- return {"frame": frame, "viz_frame": viz_frame, "objects": objects}
-
- self.stream = video_stream.pipe(ops.map(process_frame))
-
- return self.stream
-
- def get_stream(self): # type: ignore[no-untyped-def]
- """
- Returns the current detection stream if available.
- Creates a new one with the provided video_stream if not already created.
-
- Returns:
- Observable: The reactive stream of detection results
- """
- if self.stream is None:
- raise ValueError(
- "Stream not initialized. Either provide a video_stream during initialization or call create_stream first."
- )
- return self.stream
-
- def get_formatted_stream(self): # type: ignore[no-untyped-def]
- """
- Returns a formatted stream of object detection data for better readability.
- This is especially useful for LLMs like Claude that need structured text input.
-
- Returns:
- Observable: A stream of formatted string representations of object data
- """
- if self.stream is None:
- raise ValueError(
- "Stream not initialized. Either provide a video_stream during initialization or call create_stream first."
- )
-
- def format_detection_data(result): # type: ignore[no-untyped-def]
- # Extract objects from result
- objects = result.get("objects", [])
-
- if not objects:
- return "No objects detected."
-
- formatted_data = "[DETECTED OBJECTS]\n"
- try:
- for i, obj in enumerate(objects):
- pos = obj["position"]
- rot = obj["rotation"]
- size = obj["size"]
- bbox = obj["bbox"]
-
- # Format each object with a multiline f-string for better readability
- bbox_str = f"[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]"
- formatted_data += (
- f"Object {i + 1}: {obj['label']}\n"
- f" ID: {obj['object_id']}\n"
- f" Confidence: {obj['confidence']:.2f}\n"
- f" Position: x={pos.x:.2f}m, y={pos.y:.2f}m, z={pos.z:.2f}m\n"
- f" Rotation: yaw={rot.z:.2f} rad\n"
- f" Size: width={size['width']:.2f}m, height={size['height']:.2f}m\n"
- f" Depth: {obj['depth']:.2f}m\n"
- f" Bounding box: {bbox_str}\n"
- "----------------------------------\n"
- )
- except Exception as e:
- logger.warning(f"Error formatting object {i}: {e}")
- formatted_data += f"Object {i + 1}: [Error formatting data]"
- formatted_data += "\n----------------------------------\n"
-
- return formatted_data
-
- # Return a new stream with the formatter applied
- return self.stream.pipe(ops.map(format_detection_data))
-
- def cleanup(self) -> None:
- """Clean up resources."""
- pass
diff --git a/dimos/perception/object_scene_registration.py b/dimos/perception/object_scene_registration.py
new file mode 100644
index 0000000000..c21e31bf33
--- /dev/null
+++ b/dimos/perception/object_scene_registration.py
@@ -0,0 +1,247 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dimos.core import In, Out, rpc
+from dimos.core.skill_module import SkillModule
+from dimos.msgs.foxglove_msgs import ImageAnnotations
+from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
+from dimos.msgs.std_msgs import Header
+from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray
+from dimos.perception.detection.detectors.yoloe import Yoloe2DDetector, YoloePromptMode
+from dimos.perception.detection.objectDB import ObjectDB
+from dimos.perception.detection.type import ImageDetections2D
+from dimos.perception.detection.type.detection3d.object import (
+ Object,
+ aggregate_pointclouds,
+ to_detection3d_array,
+)
+from dimos.protocol.skill.skill import skill
+from dimos.types.timestamped import align_timestamped
+from dimos.utils.logging_config import setup_logger
+from dimos.utils.reactive import backpressure
+
+logger = setup_logger()
+
+
+class ObjectSceneRegistrationModule(SkillModule):
+ """Module for detecting objects in camera images using YOLO-E with 2D and 3D detection."""
+
+ color_image: In[Image]
+ depth_image: In[Image]
+ camera_info: In[CameraInfo]
+
+ detections_2d: Out[Detection2DArray]
+ detections_3d: Out[Detection3DArray]
+ overlay: Out[ImageAnnotations]
+ pointcloud: Out[PointCloud2]
+
+ _detector: Yoloe2DDetector | None = None
+ _camera_info: CameraInfo | None = None
+ _object_db: ObjectDB
+
+ def __init__(
+ self,
+ target_frame: str = "map",
+ prompt_mode: YoloePromptMode = YoloePromptMode.LRPC,
+ ) -> None:
+ super().__init__()
+ self._target_frame = target_frame
+ self._prompt_mode = prompt_mode
+ self._object_db = ObjectDB()
+
+ @rpc
+ def start(self) -> None:
+ super().start()
+
+ if self._prompt_mode == YoloePromptMode.LRPC:
+ model_name = "yoloe-11l-seg-pf.pt"
+ else:
+ model_name = "yoloe-11l-seg.pt"
+
+ self._detector = Yoloe2DDetector(
+ model_name=model_name,
+ prompt_mode=self._prompt_mode,
+ )
+
+ self.camera_info.subscribe(lambda msg: setattr(self, "_camera_info", msg))
+
+ aligned_frames = align_timestamped(
+ self.color_image.observable(), # type: ignore[no-untyped-call]
+ self.depth_image.observable(), # type: ignore[no-untyped-call]
+ buffer_size=2.0,
+ match_tolerance=0.1,
+ )
+ backpressure(aligned_frames).subscribe(self._on_aligned_frames)
+
+ @rpc
+ def stop(self) -> None:
+ """Stop the module and clean up resources."""
+
+ if self._detector:
+ self._detector.stop()
+ self._detector = None
+
+ self._object_db.clear()
+
+ logger.info("ObjectSceneRegistrationModule stopped")
+ super().stop()
+
+ @rpc
+ def set_prompts(
+ self,
+ text: list[str] | None = None,
+ bboxes: NDArray[np.float64] | None = None,
+ ) -> None:
+ """Set prompts for detection. Provide either text or bboxes, not both."""
+ if self._detector is not None:
+ self._detector.set_prompts(text=text, bboxes=bboxes)
+
+ @rpc
+ def select_object(self, track_id: int) -> dict[str, Any] | None:
+ """Get object data by track_id and promote to permanent."""
+ for obj in self._object_db.get_all_objects():
+ if obj.track_id == track_id:
+ self._object_db.promote(obj.object_id)
+ return obj.to_dict()
+ return None
+
+ @rpc
+ def get_object_track_ids(self) -> list[int]:
+ """Get track_ids of all permanent objects."""
+ return [obj.track_id for obj in self._object_db.get_all_objects()]
+
+ @skill()
+ def detect(self, *prompts: str) -> str:
+ """Detect objects matching the given text prompts. Returns track_ids after 2 seconds of detection.
+
+ Do NOT call this tool multiple times for one query. Pass all objects in a single call.
+ For example, to detect a cup and mouse, call detect("cup", "mouse") not detect("cup") then detect("mouse").
+
+ Args:
+ prompts (str): Text descriptions of objects to detect (e.g., "person", "car", "dog")
+
+ Returns:
+ str: A message containing the track_ids of detected objects
+
+ Example:
+ detect("person", "car", "dog")
+ detect("person")
+ """
+ if not prompts:
+ return "No prompts provided."
+ if self._detector is None:
+ return "Detector not initialized."
+
+ self._detector.set_prompts(text=list(prompts))
+ time.sleep(2.0)
+
+ track_ids = self.get_object_track_ids()
+ if not track_ids:
+ return "No objects detected."
+ return f"Detected objects with track_ids: {track_ids}"
+
+ @skill()
+ def select(self, track_id: int) -> str:
+ """Select an object by track_id and promote it to permanent.
+
+ Example:
+ select(5)
+ """
+ result = self.select_object(track_id)
+ if result is None:
+ return f"No object found with track_id {track_id}."
+ return f"Selected object {track_id}: {result['name']}"
+
+ def _on_aligned_frames(self, frames) -> None: # type: ignore[no-untyped-def]
+ color_msg, depth_msg = frames
+ self._process_images(color_msg, depth_msg)
+
+ def _process_images(self, color_msg: Image, depth_msg: Image) -> None:
+ """Process synchronized color and depth images (runs in background thread)."""
+ if not self._detector or not self._camera_info:
+ return
+
+ color_image = color_msg
+ depth_image = depth_msg.to_depth_meters()
+
+ # Run 2D detection
+ detections_2d: ImageDetections2D[Any] = self._detector.process_image(color_image)
+
+ detections_2d_msg = Detection2DArray(
+ detections_length=len(detections_2d.detections),
+ header=Header(color_image.ts, color_image.frame_id or ""),
+ detections=[det.to_ros_detection2d() for det in detections_2d.detections],
+ )
+ self.detections_2d.publish(detections_2d_msg)
+
+ overlay_annotations = detections_2d.to_foxglove_annotations()
+ self.overlay.publish(overlay_annotations)
+
+ # Process 3D detections
+ self._process_3d_detections(detections_2d, color_image, depth_image)
+
+ def _process_3d_detections(
+ self,
+ detections_2d: ImageDetections2D[Any],
+ color_image: Image,
+ depth_image: Image,
+ ) -> None:
+ """Convert 2D detections to 3D and publish."""
+ if self._camera_info is None:
+ return
+
+ # Look up transform from camera frame to target frame (e.g., map)
+ camera_transform = None
+ if self._target_frame != color_image.frame_id:
+ camera_transform = self.tf.get(
+ self._target_frame,
+ color_image.frame_id,
+ color_image.ts,
+ 0.1,
+ )
+ if camera_transform is None:
+ logger.warning("Failed to lookup transform from camera frame to target frame")
+ return
+
+ objects = Object.from_2d_to_list(
+ detections_2d=detections_2d,
+ color_image=color_image,
+ depth_image=depth_image,
+ camera_info=self._camera_info,
+ camera_transform=camera_transform,
+ )
+ if not objects:
+ return
+
+ # Add objects to spatial memory database
+ objects = self._object_db.add_objects(objects)
+
+ detections_3d = to_detection3d_array(objects)
+ self.detections_3d.publish(detections_3d)
+
+ objects_for_pc = self._object_db.get_objects()
+ aggregated_pc = aggregate_pointclouds(objects_for_pc)
+ self.pointcloud.publish(aggregated_pc)
+ return
+
+
+object_scene_registration_module = ObjectSceneRegistrationModule.blueprint
+
+__all__ = ["ObjectSceneRegistrationModule", "object_scene_registration_module"]
diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py
index 9260003ce2..54a5873435 100644
--- a/dimos/perception/object_tracker.py
+++ b/dimos/perception/object_tracker.py
@@ -28,7 +28,6 @@
from reactivex.disposable import Disposable
from dimos.core import In, Module, ModuleConfig, Out, rpc
-from dimos.manipulation.visual_servoing.utils import visualize_detections_3d
from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3
from dimos.msgs.sensor_msgs import (
CameraInfo,
@@ -540,10 +539,13 @@ def _process_tracking(self) -> None:
y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2
bbox_2d = [[x1, y1, x2, y2]]
- # Create visualization
- viz_image = visualize_detections_3d(
- frame, detections_3d, show_coordinates=True, bboxes_2d=bbox_2d
- )
+ # Use frame directly for visualization
+ viz_image = frame.copy()
+
+ # Draw bounding boxes
+ for bbox in bbox_2d:
+ x1, y1, x2, y2 = map(int, bbox)
+ cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Overlay REID feature matches if available
if self.last_good_matches and self.last_roi_kps and self.last_roi_bbox:
diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py
index 22846e1e2f..fa6361ac65 100644
--- a/dimos/perception/object_tracker_3d.py
+++ b/dimos/perception/object_tracker_3d.py
@@ -14,6 +14,7 @@
# Import LCM messages
+import cv2
from dimos_lcm.sensor_msgs import CameraInfo
from dimos_lcm.vision_msgs import (
Detection3D,
@@ -22,7 +23,6 @@
import numpy as np
from dimos.core import In, Out, rpc
-from dimos.manipulation.visual_servoing.utils import visualize_detections_3d
from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3
from dimos.msgs.sensor_msgs import Image, ImageFormat
from dimos.msgs.std_msgs import Header
@@ -145,10 +145,13 @@ def _process_tracking(self) -> None:
y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2
bbox_2d = [[x1, y1, x2, y2]]
- # Create 3D visualization
- viz_image = visualize_detections_3d(
- frame, detection_3d.detections, show_coordinates=True, bboxes_2d=bbox_2d
- )
+ # Use frame directly for visualization
+ viz_image = frame.copy()
+
+ # Draw bounding boxes
+ for bbox in bbox_2d:
+ x1, y1, x2, y2 = map(int, bbox)
+ cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Overlay Re-ID matches
if self.last_good_matches and self.last_roi_kps and self.last_roi_bbox:
diff --git a/dimos/perception/person_tracker.py b/dimos/perception/person_tracker.py
deleted file mode 100644
index a138467850..0000000000
--- a/dimos/perception/person_tracker.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import cv2
-import numpy as np
-from reactivex import Observable, interval, operators as ops
-from reactivex.disposable import Disposable
-
-from dimos.core import In, Module, Out, rpc
-from dimos.msgs.sensor_msgs import Image
-from dimos.perception.common.ibvs import PersonDistanceEstimator
-from dimos.perception.detection2d.utils import filter_detections
-from dimos.perception.detection2d.yolo_2d_det import ( # type: ignore[import-not-found, import-untyped]
- Yolo2DDetector,
-)
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class PersonTrackingStream(Module):
- """Module for person tracking with LCM input/output."""
-
- # LCM inputs
- video: In[Image]
-
- # LCM outputs
- tracking_data: Out[dict] # type: ignore[type-arg]
-
- def __init__( # type: ignore[no-untyped-def]
- self,
- camera_intrinsics=None,
- camera_pitch: float = 0.0,
- camera_height: float = 1.0,
- ) -> None:
- """
- Initialize a person tracking stream using Yolo2DDetector and PersonDistanceEstimator.
-
- Args:
- camera_intrinsics: List in format [fx, fy, cx, cy] where:
- - fx: Focal length in x direction (pixels)
- - fy: Focal length in y direction (pixels)
- - cx: Principal point x-coordinate (pixels)
- - cy: Principal point y-coordinate (pixels)
- camera_pitch: Camera pitch angle in radians (positive is up)
- camera_height: Height of the camera from the ground in meters
- """
- # Call parent Module init
- super().__init__()
-
- self.camera_intrinsics = camera_intrinsics
- self.camera_pitch = camera_pitch
- self.camera_height = camera_height
-
- self.detector = Yolo2DDetector()
-
- # Initialize distance estimator
- if camera_intrinsics is None:
- raise ValueError("Camera intrinsics are required for distance estimation")
-
- # Validate camera intrinsics format [fx, fy, cx, cy]
- if (
- not isinstance(camera_intrinsics, list | tuple | np.ndarray)
- or len(camera_intrinsics) != 4
- ):
- raise ValueError("Camera intrinsics must be provided as [fx, fy, cx, cy]")
-
- # Convert [fx, fy, cx, cy] to 3x3 camera matrix
- fx, fy, cx, cy = camera_intrinsics
- K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
-
- self.distance_estimator = PersonDistanceEstimator(
- K=K, camera_pitch=camera_pitch, camera_height=camera_height
- )
-
- # For tracking latest frame data
- self._latest_frame: np.ndarray | None = None # type: ignore[type-arg]
- self._process_interval = 0.1 # Process at 10Hz
-
- # Tracking state - starts disabled
- self._tracking_enabled = False
-
- @rpc
- def start(self) -> None:
- """Start the person tracking module and subscribe to LCM streams."""
-
- super().start()
-
- # Subscribe to video stream
- def set_video(image_msg: Image) -> None:
- if hasattr(image_msg, "data"):
- self._latest_frame = image_msg.data
- else:
- logger.warning("Received image message without data attribute")
-
- unsub = self.video.subscribe(set_video)
- self._disposables.add(Disposable(unsub))
-
- # Start periodic processing
- unsub = interval(self._process_interval).subscribe(lambda _: self._process_frame()) # type: ignore[assignment]
- self._disposables.add(unsub) # type: ignore[arg-type]
-
- logger.info("PersonTracking module started and subscribed to LCM streams")
-
- @rpc
- def stop(self) -> None:
- super().stop()
-
- def _process_frame(self) -> None:
- """Process the latest frame if available."""
- if self._latest_frame is None:
- return
-
- # Only process and publish if tracking is enabled
- if not self._tracking_enabled:
- return
-
- # Process frame through tracking pipeline
- result = self._process_tracking(self._latest_frame) # type: ignore[no-untyped-call]
-
- # Publish result to LCM
- if result:
- self.tracking_data.publish(result)
-
- def _process_tracking(self, frame): # type: ignore[no-untyped-def]
- """Process a single frame for person tracking."""
- # Detect people in the frame
- bboxes, track_ids, class_ids, confidences, names = self.detector.process_image(frame)
-
- # Filter to keep only person detections using filter_detections
- (
- filtered_bboxes,
- filtered_track_ids,
- filtered_class_ids,
- filtered_confidences,
- filtered_names,
- ) = filter_detections(
- bboxes,
- track_ids,
- class_ids,
- confidences,
- names,
- class_filter=[0], # 0 is the class_id for person
- name_filter=["person"],
- )
-
- # Create visualization
- viz_frame = self.detector.visualize_results(
- frame,
- filtered_bboxes,
- filtered_track_ids,
- filtered_class_ids,
- filtered_confidences,
- filtered_names,
- )
-
- # Calculate distance and angle for each person
- targets = []
- for i, bbox in enumerate(filtered_bboxes):
- target_data = {
- "target_id": filtered_track_ids[i] if i < len(filtered_track_ids) else -1,
- "bbox": bbox,
- "confidence": filtered_confidences[i] if i < len(filtered_confidences) else None,
- }
-
- distance, angle = self.distance_estimator.estimate_distance_angle(bbox)
- target_data["distance"] = distance
- target_data["angle"] = angle
-
- # Add text to visualization
- _x1, y1, x2, _y2 = map(int, bbox)
- dist_text = f"{distance:.2f}m, {np.rad2deg(angle):.1f} deg"
-
- # Add black background for better visibility
- text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
- # Position at top-right corner
- cv2.rectangle(
- viz_frame, (x2 - text_size[0], y1 - text_size[1] - 5), (x2, y1), (0, 0, 0), -1
- )
-
- # Draw text in white at top-right
- cv2.putText(
- viz_frame,
- dist_text,
- (x2 - text_size[0], y1 - 5),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- (255, 255, 255),
- 2,
- )
-
- targets.append(target_data)
-
- # Create the result dictionary
- return {"frame": frame, "viz_frame": viz_frame, "targets": targets}
-
- @rpc
- def enable_tracking(self) -> bool:
- """Enable person tracking.
-
- Returns:
- bool: True if tracking was enabled successfully
- """
- self._tracking_enabled = True
- logger.info("Person tracking enabled")
- return True
-
- @rpc
- def disable_tracking(self) -> bool:
- """Disable person tracking.
-
- Returns:
- bool: True if tracking was disabled successfully
- """
- self._tracking_enabled = False
- logger.info("Person tracking disabled")
- return True
-
- @rpc
- def is_tracking_enabled(self) -> bool:
- """Check if tracking is currently enabled.
-
- Returns:
- bool: True if tracking is enabled
- """
- return self._tracking_enabled
-
- @rpc
- def get_tracking_data(self) -> dict: # type: ignore[type-arg]
- """Get the latest tracking data.
-
- Returns:
- Dictionary containing tracking results
- """
- if self._latest_frame is not None:
- return self._process_tracking(self._latest_frame) # type: ignore[no-any-return, no-untyped-call]
- return {"frame": None, "viz_frame": None, "targets": []}
-
- def create_stream(self, video_stream: Observable) -> Observable: # type: ignore[type-arg]
- """
- Create an Observable stream of person tracking results from a video stream.
-
- Args:
- video_stream: Observable that emits video frames
-
- Returns:
- Observable that emits dictionaries containing tracking results and visualizations
- """
-
- return video_stream.pipe(ops.map(self._process_tracking))
diff --git a/dimos/perception/pointcloud/__init__.py b/dimos/perception/pointcloud/__init__.py
deleted file mode 100644
index a380e2aadf..0000000000
--- a/dimos/perception/pointcloud/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .cuboid_fit import *
-from .pointcloud_filtering import *
-from .utils import *
diff --git a/dimos/perception/pointcloud/cuboid_fit.py b/dimos/perception/pointcloud/cuboid_fit.py
deleted file mode 100644
index dfec2d9297..0000000000
--- a/dimos/perception/pointcloud/cuboid_fit.py
+++ /dev/null
@@ -1,420 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import cv2
-import numpy as np
-import open3d as o3d # type: ignore[import-untyped]
-
-
-def fit_cuboid(
- points: np.ndarray | o3d.geometry.PointCloud, # type: ignore[type-arg]
- method: str = "minimal",
-) -> dict | None: # type: ignore[type-arg]
- """
- Fit a cuboid to a point cloud using Open3D's built-in methods.
-
- Args:
- points: Nx3 array of points or Open3D PointCloud
- method: Fitting method:
- - 'minimal': Minimal oriented bounding box (best fit)
- - 'oriented': PCA-based oriented bounding box
- - 'axis_aligned': Axis-aligned bounding box
-
- Returns:
- Dictionary containing:
- - center: 3D center point
- - dimensions: 3D dimensions (extent)
- - rotation: 3x3 rotation matrix
- - error: Fitting error
- - bounding_box: Open3D OrientedBoundingBox object
- Returns None if insufficient points or fitting fails.
-
- Raises:
- ValueError: If method is invalid or inputs are malformed
- """
- # Validate method
- valid_methods = ["minimal", "oriented", "axis_aligned"]
- if method not in valid_methods:
- raise ValueError(f"method must be one of {valid_methods}, got '{method}'")
-
- # Convert to point cloud if needed
- if isinstance(points, np.ndarray):
- points = np.asarray(points)
- if len(points.shape) != 2 or points.shape[1] != 3:
- raise ValueError(f"points array must be Nx3, got shape {points.shape}")
- if len(points) < 4:
- return None
-
- pcd = o3d.geometry.PointCloud()
- pcd.points = o3d.utility.Vector3dVector(points)
- elif isinstance(points, o3d.geometry.PointCloud):
- pcd = points
- points = np.asarray(pcd.points)
- if len(points) < 4:
- return None
- else:
- raise ValueError(f"points must be numpy array or Open3D PointCloud, got {type(points)}")
-
- try:
- # Get bounding box based on method
- if method == "minimal":
- obb = pcd.get_minimal_oriented_bounding_box(robust=True)
- elif method == "oriented":
- obb = pcd.get_oriented_bounding_box(robust=True)
- elif method == "axis_aligned":
- # Convert axis-aligned to oriented format for consistency
- aabb = pcd.get_axis_aligned_bounding_box()
- obb = o3d.geometry.OrientedBoundingBox()
- obb.center = aabb.get_center()
- obb.extent = aabb.get_extent()
- obb.R = np.eye(3) # Identity rotation for axis-aligned
-
- # Extract parameters
- center = np.asarray(obb.center)
- dimensions = np.asarray(obb.extent)
- rotation = np.asarray(obb.R)
-
- # Calculate fitting error
- error = _compute_fitting_error(points, center, dimensions, rotation)
-
- return {
- "center": center,
- "dimensions": dimensions,
- "rotation": rotation,
- "error": error,
- "bounding_box": obb,
- "method": method,
- }
-
- except Exception as e:
- # Log error but don't crash - return None for graceful handling
- print(f"Warning: Cuboid fitting failed with method '{method}': {e}")
- return None
-
-
-def fit_cuboid_simple(points: np.ndarray | o3d.geometry.PointCloud) -> dict | None: # type: ignore[type-arg]
- """
- Simple wrapper for minimal oriented bounding box fitting.
-
- Args:
- points: Nx3 array of points or Open3D PointCloud
-
- Returns:
- Dictionary with center, dimensions, rotation, and bounding_box,
- or None if insufficient points
- """
- return fit_cuboid(points, method="minimal")
-
-
-def _compute_fitting_error(
- points: np.ndarray, # type: ignore[type-arg]
- center: np.ndarray, # type: ignore[type-arg]
- dimensions: np.ndarray, # type: ignore[type-arg]
- rotation: np.ndarray, # type: ignore[type-arg]
-) -> float:
- """
- Compute fitting error as mean squared distance from points to cuboid surface.
-
- Args:
- points: Nx3 array of points
- center: 3D center point
- dimensions: 3D dimensions
- rotation: 3x3 rotation matrix
-
- Returns:
- Mean squared error
- """
- if len(points) == 0:
- return 0.0
-
- # Transform points to local coordinates
- local_points = (points - center) @ rotation
- half_dims = dimensions / 2
-
- # Calculate distance to cuboid surface
- dx = np.abs(local_points[:, 0]) - half_dims[0]
- dy = np.abs(local_points[:, 1]) - half_dims[1]
- dz = np.abs(local_points[:, 2]) - half_dims[2]
-
- # Points outside: distance to nearest face
- # Points inside: negative distance to nearest face
- outside_dist = np.sqrt(np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2)
- inside_dist = np.minimum(np.minimum(dx, dy), dz)
- distances = np.where((dx > 0) | (dy > 0) | (dz > 0), outside_dist, -inside_dist)
-
- return float(np.mean(distances**2))
-
-
-def get_cuboid_corners(
- center: np.ndarray, # type: ignore[type-arg]
- dimensions: np.ndarray, # type: ignore[type-arg]
- rotation: np.ndarray, # type: ignore[type-arg]
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Get the 8 corners of a cuboid.
-
- Args:
- center: 3D center point
- dimensions: 3D dimensions
- rotation: 3x3 rotation matrix
-
- Returns:
- 8x3 array of corner coordinates
- """
- half_dims = dimensions / 2
- corners_local = (
- np.array(
- [
- [-1, -1, -1], # 0: left bottom back
- [-1, -1, 1], # 1: left bottom front
- [-1, 1, -1], # 2: left top back
- [-1, 1, 1], # 3: left top front
- [1, -1, -1], # 4: right bottom back
- [1, -1, 1], # 5: right bottom front
- [1, 1, -1], # 6: right top back
- [1, 1, 1], # 7: right top front
- ]
- )
- * half_dims
- )
-
- # Apply rotation and translation
- return corners_local @ rotation.T + center # type: ignore[no-any-return]
-
-
-def visualize_cuboid_on_image(
- image: np.ndarray, # type: ignore[type-arg]
- cuboid_params: dict, # type: ignore[type-arg]
- camera_matrix: np.ndarray, # type: ignore[type-arg]
- extrinsic_rotation: np.ndarray | None = None, # type: ignore[type-arg]
- extrinsic_translation: np.ndarray | None = None, # type: ignore[type-arg]
- color: tuple[int, int, int] = (0, 255, 0),
- thickness: int = 2,
- show_dimensions: bool = True,
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Draw a fitted cuboid on an image using camera projection.
-
- Args:
- image: Input image to draw on
- cuboid_params: Dictionary containing cuboid parameters
- camera_matrix: Camera intrinsic matrix (3x3)
- extrinsic_rotation: Optional external rotation (3x3)
- extrinsic_translation: Optional external translation (3x1)
- color: Line color as (B, G, R) tuple
- thickness: Line thickness
- show_dimensions: Whether to display dimension text
-
- Returns:
- Image with cuboid visualization
-
- Raises:
- ValueError: If required parameters are missing or invalid
- """
- # Validate inputs
- required_keys = ["center", "dimensions", "rotation"]
- if not all(key in cuboid_params for key in required_keys):
- raise ValueError(f"cuboid_params must contain keys: {required_keys}")
-
- if camera_matrix.shape != (3, 3):
- raise ValueError(f"camera_matrix must be 3x3, got {camera_matrix.shape}")
-
- # Get corners in world coordinates
- corners = get_cuboid_corners(
- cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"]
- )
-
- # Transform corners if extrinsic parameters are provided
- if extrinsic_rotation is not None and extrinsic_translation is not None:
- if extrinsic_rotation.shape != (3, 3):
- raise ValueError(f"extrinsic_rotation must be 3x3, got {extrinsic_rotation.shape}")
- if extrinsic_translation.shape not in [(3,), (3, 1)]:
- raise ValueError(
- f"extrinsic_translation must be (3,) or (3,1), got {extrinsic_translation.shape}"
- )
-
- extrinsic_translation = extrinsic_translation.flatten()
- corners = (extrinsic_rotation @ corners.T).T + extrinsic_translation
-
- try:
- # Project 3D corners to image coordinates
- corners_img, _ = cv2.projectPoints( # type: ignore[call-overload]
- corners.astype(np.float32),
- np.zeros(3),
- np.zeros(3), # No additional rotation/translation
- camera_matrix.astype(np.float32),
- None, # No distortion
- )
- corners_img = corners_img.reshape(-1, 2).astype(int)
-
- # Check if corners are within image bounds
- h, w = image.shape[:2]
- valid_corners = (
- (corners_img[:, 0] >= 0)
- & (corners_img[:, 0] < w)
- & (corners_img[:, 1] >= 0)
- & (corners_img[:, 1] < h)
- )
-
- if not np.any(valid_corners):
- print("Warning: All cuboid corners are outside image bounds")
- return image.copy()
-
- except Exception as e:
- print(f"Warning: Failed to project cuboid corners: {e}")
- return image.copy()
-
- # Define edges for wireframe visualization
- edges = [
- # Bottom face
- (0, 1),
- (1, 5),
- (5, 4),
- (4, 0),
- # Top face
- (2, 3),
- (3, 7),
- (7, 6),
- (6, 2),
- # Vertical edges
- (0, 2),
- (1, 3),
- (5, 7),
- (4, 6),
- ]
-
- # Draw edges
- vis_img = image.copy()
- for i, j in edges:
- # Only draw edge if both corners are valid
- if valid_corners[i] and valid_corners[j]:
- cv2.line(vis_img, tuple(corners_img[i]), tuple(corners_img[j]), color, thickness)
-
- # Add dimension text if requested
- if show_dimensions and np.any(valid_corners):
- dims = cuboid_params["dimensions"]
- dim_text = f"Dims: {dims[0]:.3f} x {dims[1]:.3f} x {dims[2]:.3f}"
-
- # Find a good position for text (top-left of image)
- text_pos = (10, 30)
- font_scale = 0.7
-
- # Add background rectangle for better readability
- text_size = cv2.getTextSize(dim_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2)[0]
- cv2.rectangle(
- vis_img,
- (text_pos[0] - 5, text_pos[1] - text_size[1] - 5),
- (text_pos[0] + text_size[0] + 5, text_pos[1] + 5),
- (0, 0, 0),
- -1,
- )
-
- cv2.putText(vis_img, dim_text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, 2)
-
- return vis_img
-
-
-def compute_cuboid_volume(cuboid_params: dict) -> float: # type: ignore[type-arg]
- """
- Compute the volume of a cuboid.
-
- Args:
- cuboid_params: Dictionary containing cuboid parameters
-
- Returns:
- Volume in cubic units
- """
- if "dimensions" not in cuboid_params:
- raise ValueError("cuboid_params must contain 'dimensions' key")
-
- dims = cuboid_params["dimensions"]
- return float(np.prod(dims))
-
-
-def compute_cuboid_surface_area(cuboid_params: dict) -> float: # type: ignore[type-arg]
- """
- Compute the surface area of a cuboid.
-
- Args:
- cuboid_params: Dictionary containing cuboid parameters
-
- Returns:
- Surface area in square units
- """
- if "dimensions" not in cuboid_params:
- raise ValueError("cuboid_params must contain 'dimensions' key")
-
- dims = cuboid_params["dimensions"]
- return 2.0 * (dims[0] * dims[1] + dims[1] * dims[2] + dims[2] * dims[0]) # type: ignore[no-any-return]
-
-
-def check_cuboid_quality(cuboid_params: dict, points: np.ndarray) -> dict: # type: ignore[type-arg]
- """
- Assess the quality of a cuboid fit.
-
- Args:
- cuboid_params: Dictionary containing cuboid parameters
- points: Original points used for fitting
-
- Returns:
- Dictionary with quality metrics
- """
- if len(points) == 0:
- return {"error": "No points provided"}
-
- # Basic metrics
- volume = compute_cuboid_volume(cuboid_params)
- surface_area = compute_cuboid_surface_area(cuboid_params)
- error = cuboid_params.get("error", 0.0)
-
- # Aspect ratio analysis
- dims = cuboid_params["dimensions"]
- aspect_ratios = [
- dims[0] / dims[1] if dims[1] > 0 else float("inf"),
- dims[1] / dims[2] if dims[2] > 0 else float("inf"),
- dims[2] / dims[0] if dims[0] > 0 else float("inf"),
- ]
- max_aspect_ratio = max(aspect_ratios)
-
- # Volume ratio (cuboid volume vs convex hull volume)
- try:
- pcd = o3d.geometry.PointCloud()
- pcd.points = o3d.utility.Vector3dVector(points)
- hull, _ = pcd.compute_convex_hull()
- hull_volume = hull.get_volume()
- volume_ratio = volume / hull_volume if hull_volume > 0 else float("inf")
- except:
- volume_ratio = None
-
- return {
- "fitting_error": error,
- "volume": volume,
- "surface_area": surface_area,
- "max_aspect_ratio": max_aspect_ratio,
- "volume_ratio": volume_ratio,
- "num_points": len(points),
- "method": cuboid_params.get("method", "unknown"),
- }
-
-
-# Backward compatibility
-def visualize_fit(image, cuboid_params, camera_matrix, R=None, t=None): # type: ignore[no-untyped-def]
- """
- Legacy function for backward compatibility.
- Use visualize_cuboid_on_image instead.
- """
- return visualize_cuboid_on_image(
- image, cuboid_params, camera_matrix, R, t, show_dimensions=True
- )
diff --git a/dimos/perception/pointcloud/pointcloud_filtering.py b/dimos/perception/pointcloud/pointcloud_filtering.py
deleted file mode 100644
index d6aa2b835f..0000000000
--- a/dimos/perception/pointcloud/pointcloud_filtering.py
+++ /dev/null
@@ -1,370 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import cv2
-import numpy as np
-import open3d as o3d # type: ignore[import-untyped]
-import torch
-
-from dimos.perception.pointcloud.cuboid_fit import fit_cuboid
-from dimos.perception.pointcloud.utils import (
- create_point_cloud_and_extract_masks,
- load_camera_matrix_from_yaml,
-)
-from dimos.types.manipulation import ObjectData
-from dimos.types.vector import Vector
-
-
-class PointcloudFiltering:
- """
- A production-ready point cloud filtering pipeline for segmented objects.
-
- This class takes segmentation results and produces clean, filtered point clouds
- for each object with consistent coloring and optional outlier removal.
- """
-
- def __init__(
- self,
- color_intrinsics: str | list[float] | np.ndarray | None = None, # type: ignore[type-arg]
- depth_intrinsics: str | list[float] | np.ndarray | None = None, # type: ignore[type-arg]
- color_weight: float = 0.3,
- enable_statistical_filtering: bool = True,
- statistical_neighbors: int = 20,
- statistical_std_ratio: float = 1.5,
- enable_radius_filtering: bool = True,
- radius_filtering_radius: float = 0.015,
- radius_filtering_min_neighbors: int = 25,
- enable_subsampling: bool = True,
- voxel_size: float = 0.005,
- max_num_objects: int = 10,
- min_points_for_cuboid: int = 10,
- cuboid_method: str = "oriented",
- max_bbox_size_percent: float = 30.0,
- ) -> None:
- """
- Initialize the point cloud filtering pipeline.
-
- Args:
- color_intrinsics: Camera intrinsics for color image
- depth_intrinsics: Camera intrinsics for depth image
- color_weight: Weight for blending generated color with original (0.0-1.0)
- enable_statistical_filtering: Enable/disable statistical outlier filtering
- statistical_neighbors: Number of neighbors for statistical filtering
- statistical_std_ratio: Standard deviation ratio for statistical filtering
- enable_radius_filtering: Enable/disable radius outlier filtering
- radius_filtering_radius: Search radius for radius filtering (meters)
- radius_filtering_min_neighbors: Min neighbors within radius
- enable_subsampling: Enable/disable point cloud subsampling
- voxel_size: Voxel size for downsampling (meters, when subsampling enabled)
- max_num_objects: Maximum number of objects to process (top N by confidence)
- min_points_for_cuboid: Minimum points required for cuboid fitting
- cuboid_method: Method for cuboid fitting ('minimal', 'oriented', 'axis_aligned')
- max_bbox_size_percent: Maximum percentage of image size for object bboxes (0-100)
-
- Raises:
- ValueError: If invalid parameters are provided
- """
- # Validate parameters
- if not 0.0 <= color_weight <= 1.0:
- raise ValueError(f"color_weight must be between 0.0 and 1.0, got {color_weight}")
- if not 0.0 <= max_bbox_size_percent <= 100.0:
- raise ValueError(
- f"max_bbox_size_percent must be between 0.0 and 100.0, got {max_bbox_size_percent}"
- )
-
- # Store settings
- self.color_weight = color_weight
- self.enable_statistical_filtering = enable_statistical_filtering
- self.statistical_neighbors = statistical_neighbors
- self.statistical_std_ratio = statistical_std_ratio
- self.enable_radius_filtering = enable_radius_filtering
- self.radius_filtering_radius = radius_filtering_radius
- self.radius_filtering_min_neighbors = radius_filtering_min_neighbors
- self.enable_subsampling = enable_subsampling
- self.voxel_size = voxel_size
- self.max_num_objects = max_num_objects
- self.min_points_for_cuboid = min_points_for_cuboid
- self.cuboid_method = cuboid_method
- self.max_bbox_size_percent = max_bbox_size_percent
-
- # Load camera matrices
- self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics)
- self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics)
-
- # Store the full point cloud
- self.full_pcd = None
-
- def generate_color_from_id(self, object_id: int) -> np.ndarray: # type: ignore[type-arg]
- """Generate a consistent color for a given object ID."""
- np.random.seed(object_id)
- color = np.random.randint(0, 255, 3, dtype=np.uint8)
- np.random.seed(None)
- return color
-
- def _validate_inputs( # type: ignore[no-untyped-def]
- self,
- color_img: np.ndarray, # type: ignore[type-arg]
- depth_img: np.ndarray, # type: ignore[type-arg]
- objects: list[ObjectData],
- ):
- """Validate input parameters."""
- if color_img.shape[:2] != depth_img.shape:
- raise ValueError("Color and depth image dimensions don't match")
-
- def _prepare_masks(self, masks: list[np.ndarray], target_shape: tuple) -> list[np.ndarray]: # type: ignore[type-arg]
- """Prepare and validate masks to match target shape."""
- processed_masks = []
- for mask in masks:
- # Convert mask to numpy if it's a tensor
- if hasattr(mask, "cpu"):
- mask = mask.cpu().numpy()
-
- mask = mask.astype(bool)
-
- # Handle shape mismatches
- if mask.shape != target_shape:
- if len(mask.shape) > 2:
- mask = mask[:, :, 0]
-
- if mask.shape != target_shape:
- mask = cv2.resize(
- mask.astype(np.uint8),
- (target_shape[1], target_shape[0]),
- interpolation=cv2.INTER_NEAREST,
- ).astype(bool)
-
- processed_masks.append(mask)
-
- return processed_masks
-
- def _apply_color_mask(
- self,
- pcd: o3d.geometry.PointCloud,
- rgb_color: np.ndarray, # type: ignore[type-arg]
- ) -> o3d.geometry.PointCloud:
- """Apply weighted color mask to point cloud."""
- if len(np.asarray(pcd.colors)) > 0:
- original_colors = np.asarray(pcd.colors)
- generated_color = rgb_color.astype(np.float32) / 255.0
- colored_mask = (
- 1.0 - self.color_weight
- ) * original_colors + self.color_weight * generated_color
- colored_mask = np.clip(colored_mask, 0.0, 1.0)
- pcd.colors = o3d.utility.Vector3dVector(colored_mask)
- return pcd
-
- def _apply_filtering(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud:
- """Apply optional filtering to point cloud based on enabled flags."""
- current_pcd = pcd
-
- # Apply statistical filtering if enabled
- if self.enable_statistical_filtering:
- current_pcd, _ = current_pcd.remove_statistical_outlier(
- nb_neighbors=self.statistical_neighbors, std_ratio=self.statistical_std_ratio
- )
-
- # Apply radius filtering if enabled
- if self.enable_radius_filtering:
- current_pcd, _ = current_pcd.remove_radius_outlier(
- nb_points=self.radius_filtering_min_neighbors, radius=self.radius_filtering_radius
- )
-
- return current_pcd
-
- def _apply_subsampling(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud:
- """Apply subsampling to limit point cloud size using Open3D's voxel downsampling."""
- if self.enable_subsampling:
- return pcd.voxel_down_sample(self.voxel_size)
- return pcd
-
- def _extract_masks_from_objects(self, objects: list[ObjectData]) -> list[np.ndarray]: # type: ignore[type-arg]
- """Extract segmentation masks from ObjectData objects."""
- return [obj["segmentation_mask"] for obj in objects]
-
- def get_full_point_cloud(self) -> o3d.geometry.PointCloud:
- """Get the full point cloud."""
- return self._apply_subsampling(self.full_pcd)
-
- def process_images(
- self,
- color_img: np.ndarray, # type: ignore[type-arg]
- depth_img: np.ndarray, # type: ignore[type-arg]
- objects: list[ObjectData],
- ) -> list[ObjectData]:
- """
- Process color and depth images with object detection results to create filtered point clouds.
-
- Args:
- color_img: RGB image as numpy array (H, W, 3)
- depth_img: Depth image as numpy array (H, W) in meters
- objects: List of ObjectData from object detection stream
-
- Returns:
- List of updated ObjectData with pointcloud and 3D information. Each ObjectData
- dictionary is enhanced with the following new fields:
-
- **3D Spatial Information** (added when sufficient points for cuboid fitting):
- - "position": Vector(x, y, z) - 3D center position in world coordinates (meters)
- - "rotation": Vector(roll, pitch, yaw) - 3D orientation as Euler angles (radians)
- - "size": {"width": float, "height": float, "depth": float} - 3D bounding box dimensions (meters)
-
- **Point Cloud Data**:
- - "point_cloud": o3d.geometry.PointCloud - Filtered Open3D point cloud with colors
- - "color": np.ndarray - Consistent RGB color [R,G,B] (0-255) generated from object_id
-
- **Grasp Generation Arrays** (Dimensional grasp format):
- - "point_cloud_numpy": np.ndarray - Nx3 XYZ coordinates as float32 (meters)
- - "colors_numpy": np.ndarray - Nx3 RGB colors as float32 (0.0-1.0 range)
-
- Raises:
- ValueError: If inputs are invalid
- RuntimeError: If processing fails
- """
- # Validate inputs
- self._validate_inputs(color_img, depth_img, objects)
-
- if not objects:
- return []
-
- # Filter to top N objects by confidence
- if len(objects) > self.max_num_objects:
- # Sort objects by confidence (highest first), handle None confidences
- sorted_objects = sorted(
- objects,
- key=lambda obj: obj.get("confidence", 0.0)
- if obj.get("confidence") is not None
- else 0.0,
- reverse=True,
- )
- objects = sorted_objects[: self.max_num_objects]
-
- # Filter out objects with bboxes too large
- image_area = color_img.shape[0] * color_img.shape[1]
- max_bbox_area = image_area * (self.max_bbox_size_percent / 100.0)
-
- filtered_objects = []
- for obj in objects:
- if "bbox" in obj and obj["bbox"] is not None:
- bbox = obj["bbox"]
- # Calculate bbox area (assuming bbox format [x1, y1, x2, y2])
- bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
- if bbox_area <= max_bbox_area:
- filtered_objects.append(obj)
- else:
- filtered_objects.append(obj)
-
- objects = filtered_objects
-
- # Extract masks from ObjectData
- masks = self._extract_masks_from_objects(objects)
-
- # Prepare masks
- processed_masks = self._prepare_masks(masks, depth_img.shape)
-
- # Create point clouds efficiently
- self.full_pcd, masked_pcds = create_point_cloud_and_extract_masks(
- color_img,
- depth_img,
- processed_masks,
- self.depth_camera_matrix, # type: ignore[arg-type]
- depth_scale=1.0,
- )
-
- # Process each object and update ObjectData
- updated_objects = []
-
- for i, (obj, _mask, pcd) in enumerate(
- zip(objects, processed_masks, masked_pcds, strict=False)
- ):
- # Skip empty point clouds
- if len(np.asarray(pcd.points)) == 0:
- continue
-
- # Create a copy of the object data to avoid modifying the original
- updated_obj = obj.copy()
-
- # Generate consistent color
- object_id = obj.get("object_id", i)
- rgb_color = self.generate_color_from_id(object_id)
-
- # Apply color mask
- pcd = self._apply_color_mask(pcd, rgb_color)
-
- # Apply subsampling to control point cloud size
- pcd = self._apply_subsampling(pcd)
-
- # Apply filtering (optional based on flags)
- pcd_filtered = self._apply_filtering(pcd)
-
- # Fit cuboid and extract 3D information
- points = np.asarray(pcd_filtered.points)
- if len(points) >= self.min_points_for_cuboid:
- cuboid_params = fit_cuboid(points, method=self.cuboid_method)
- if cuboid_params is not None:
- # Update position, rotation, and size from cuboid
- center = cuboid_params["center"]
- dimensions = cuboid_params["dimensions"]
- rotation_matrix = cuboid_params["rotation"]
-
- # Convert rotation matrix to euler angles (roll, pitch, yaw)
- sy = np.sqrt(
- rotation_matrix[0, 0] * rotation_matrix[0, 0]
- + rotation_matrix[1, 0] * rotation_matrix[1, 0]
- )
- singular = sy < 1e-6
-
- if not singular:
- roll = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
- pitch = np.arctan2(-rotation_matrix[2, 0], sy)
- yaw = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
- else:
- roll = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1])
- pitch = np.arctan2(-rotation_matrix[2, 0], sy)
- yaw = 0
-
- # Update position, rotation, and size from cuboid
- updated_obj["position"] = Vector(center[0], center[1], center[2])
- updated_obj["rotation"] = Vector(roll, pitch, yaw)
- updated_obj["size"] = {
- "width": float(dimensions[0]),
- "height": float(dimensions[1]),
- "depth": float(dimensions[2]),
- }
-
- # Add point cloud data to ObjectData
- updated_obj["point_cloud"] = pcd_filtered
- updated_obj["color"] = rgb_color
-
- # Extract numpy arrays for grasp generation
- points_array = np.asarray(pcd_filtered.points).astype(np.float32) # Nx3 XYZ coordinates
- if pcd_filtered.has_colors():
- colors_array = np.asarray(pcd_filtered.colors).astype(
- np.float32
- ) # Nx3 RGB (0-1 range)
- else:
- # If no colors, create array of zeros
- colors_array = np.zeros((len(points_array), 3), dtype=np.float32)
-
- updated_obj["point_cloud_numpy"] = points_array
- updated_obj["colors_numpy"] = colors_array # type: ignore[typeddict-unknown-key]
-
- updated_objects.append(updated_obj)
-
- return updated_objects
-
- def cleanup(self) -> None:
- """Clean up resources."""
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
diff --git a/dimos/perception/pointcloud/test_pointcloud_filtering.py b/dimos/perception/pointcloud/test_pointcloud_filtering.py
deleted file mode 100644
index 4ac7e5cb2d..0000000000
--- a/dimos/perception/pointcloud/test_pointcloud_filtering.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from typing import TYPE_CHECKING
-
-import cv2
-import numpy as np
-import open3d as o3d
-import pytest
-
-from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering
-from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml
-
-if TYPE_CHECKING:
- from dimos.types.manipulation import ObjectData
-
-
-class TestPointcloudFiltering:
- def test_pointcloud_filtering_initialization(self) -> None:
- """Test PointcloudFiltering initializes correctly with default parameters."""
- try:
- filtering = PointcloudFiltering()
- assert filtering is not None
- assert filtering.color_weight == 0.3
- assert filtering.enable_statistical_filtering
- assert filtering.enable_radius_filtering
- assert filtering.enable_subsampling
- except Exception as e:
- pytest.skip(f"Skipping test due to initialization error: {e}")
-
- def test_pointcloud_filtering_with_custom_params(self) -> None:
- """Test PointcloudFiltering with custom parameters."""
- try:
- filtering = PointcloudFiltering(
- color_weight=0.5,
- enable_statistical_filtering=False,
- enable_radius_filtering=False,
- voxel_size=0.01,
- max_num_objects=5,
- )
- assert filtering.color_weight == 0.5
- assert not filtering.enable_statistical_filtering
- assert not filtering.enable_radius_filtering
- assert filtering.voxel_size == 0.01
- assert filtering.max_num_objects == 5
- except Exception as e:
- pytest.skip(f"Skipping test due to initialization error: {e}")
-
- def test_pointcloud_filtering_process_images(self) -> None:
- """Test PointcloudFiltering can process RGB-D images and return filtered point clouds."""
- try:
- # Import data inside method to avoid pytest fixture confusion
- from dimos.utils.data import get_data
-
- # Load test RGB-D data
- data_dir = get_data("rgbd_frames")
-
- # Load first frame
- color_path = os.path.join(data_dir, "color", "00000.png")
- depth_path = os.path.join(data_dir, "depth", "00000.png")
- intrinsics_path = os.path.join(data_dir, "color_camera_info.yaml")
-
- assert os.path.exists(color_path), f"Color image not found: {color_path}"
- assert os.path.exists(depth_path), f"Depth image not found: {depth_path}"
- assert os.path.exists(intrinsics_path), f"Intrinsics file not found: {intrinsics_path}"
-
- # Load images
- color_img = cv2.imread(color_path)
- color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)
-
- depth_img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
- if depth_img.dtype == np.uint16:
- depth_img = depth_img.astype(np.float32) / 1000.0
-
- # Load camera intrinsics
- camera_matrix = load_camera_matrix_from_yaml(intrinsics_path)
- if camera_matrix is None:
- pytest.skip("Failed to load camera intrinsics")
-
- # Create mock objects with segmentation masks
- height, width = color_img.shape[:2]
-
- # Create simple rectangular masks for testing
- mock_objects = []
-
- # Object 1: Top-left quadrant
- mask1 = np.zeros((height, width), dtype=bool)
- mask1[height // 4 : height // 2, width // 4 : width // 2] = True
-
- obj1: ObjectData = {
- "object_id": 1,
- "confidence": 0.9,
- "bbox": [width // 4, height // 4, width // 2, height // 2],
- "segmentation_mask": mask1,
- "name": "test_object_1",
- }
- mock_objects.append(obj1)
-
- # Object 2: Bottom-right quadrant
- mask2 = np.zeros((height, width), dtype=bool)
- mask2[height // 2 : 3 * height // 4, width // 2 : 3 * width // 4] = True
-
- obj2: ObjectData = {
- "object_id": 2,
- "confidence": 0.8,
- "bbox": [width // 2, height // 2, 3 * width // 4, 3 * height // 4],
- "segmentation_mask": mask2,
- "name": "test_object_2",
- }
- mock_objects.append(obj2)
-
- # Initialize filtering with intrinsics
- filtering = PointcloudFiltering(
- color_intrinsics=camera_matrix,
- depth_intrinsics=camera_matrix,
- enable_statistical_filtering=False, # Disable for faster testing
- enable_radius_filtering=False, # Disable for faster testing
- voxel_size=0.01, # Larger voxel for faster processing
- )
-
- # Process images
- results = filtering.process_images(color_img, depth_img, mock_objects)
-
- print(
- f"Processing results - Input objects: {len(mock_objects)}, Output objects: {len(results)}"
- )
-
- # Verify results
- assert isinstance(results, list), "Results should be a list"
- assert len(results) <= len(mock_objects), "Should not return more objects than input"
-
- # Check each result object
- for i, result in enumerate(results):
- print(f"Object {i}: {result.get('name', 'unknown')}")
-
- # Verify required fields exist
- assert "point_cloud" in result, "Result should contain point_cloud"
- assert "color" in result, "Result should contain color"
- assert "point_cloud_numpy" in result, "Result should contain point_cloud_numpy"
-
- # Verify point cloud is valid Open3D object
- pcd = result["point_cloud"]
- assert isinstance(pcd, o3d.geometry.PointCloud), (
- "point_cloud should be Open3D PointCloud"
- )
-
- # Verify numpy arrays
- points_array = result["point_cloud_numpy"]
- assert isinstance(points_array, np.ndarray), (
- "point_cloud_numpy should be numpy array"
- )
- assert points_array.shape[1] == 3, "Point array should have 3 columns (x,y,z)"
- assert points_array.dtype == np.float32, "Point array should be float32"
-
- # Verify color
- color = result["color"]
- assert isinstance(color, np.ndarray), "Color should be numpy array"
- assert color.shape == (3,), "Color should be RGB triplet"
- assert color.dtype == np.uint8, "Color should be uint8"
-
- # Check if 3D information was added (when enough points for cuboid fitting)
- points = np.asarray(pcd.points)
- if len(points) >= filtering.min_points_for_cuboid:
- if "position" in result:
- assert "rotation" in result, "Should have rotation if position exists"
- assert "size" in result, "Should have size if position exists"
-
- # Verify position format
- from dimos.types.vector import Vector
-
- position = result["position"]
- assert isinstance(position, Vector), "Position should be Vector"
-
- # Verify size format
- size = result["size"]
- assert isinstance(size, dict), "Size should be dict"
- assert "width" in size and "height" in size and "depth" in size
-
- print(f" - Points: {len(points)}")
- print(f" - Color: {color}")
- if "position" in result:
- print(f" - Position: {result['position']}")
- print(f" - Size: {result['size']}")
-
- # Test full point cloud access
- full_pcd = filtering.get_full_point_cloud()
- if full_pcd is not None:
- assert isinstance(full_pcd, o3d.geometry.PointCloud), (
- "Full point cloud should be Open3D PointCloud"
- )
- full_points = np.asarray(full_pcd.points)
- print(f"Full point cloud points: {len(full_points)}")
-
- print("All pointcloud filtering tests passed!")
-
- except Exception as e:
- pytest.skip(f"Skipping test due to error: {e}")
-
- def test_pointcloud_filtering_empty_objects(self) -> None:
- """Test PointcloudFiltering with empty object list."""
- try:
- from dimos.utils.data import get_data
-
- # Load test data
- data_dir = get_data("rgbd_frames")
- color_path = os.path.join(data_dir, "color", "00000.png")
- depth_path = os.path.join(data_dir, "depth", "00000.png")
-
- if not (os.path.exists(color_path) and os.path.exists(depth_path)):
- pytest.skip("Test images not found")
-
- color_img = cv2.imread(color_path)
- color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)
- depth_img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
- if depth_img.dtype == np.uint16:
- depth_img = depth_img.astype(np.float32) / 1000.0
-
- filtering = PointcloudFiltering()
-
- # Test with empty object list
- results = filtering.process_images(color_img, depth_img, [])
-
- assert isinstance(results, list), "Results should be a list"
- assert len(results) == 0, "Should return empty list for empty input"
-
- except Exception as e:
- pytest.skip(f"Skipping test due to error: {e}")
-
- def test_color_generation_consistency(self) -> None:
- """Test that color generation is consistent for the same object ID."""
- try:
- filtering = PointcloudFiltering()
-
- # Test color generation consistency
- color1 = filtering.generate_color_from_id(42)
- color2 = filtering.generate_color_from_id(42)
- color3 = filtering.generate_color_from_id(43)
-
- assert np.array_equal(color1, color2), "Same ID should generate same color"
- assert not np.array_equal(color1, color3), (
- "Different IDs should generate different colors"
- )
- assert color1.shape == (3,), "Color should be RGB triplet"
- assert color1.dtype == np.uint8, "Color should be uint8"
-
- except Exception as e:
- pytest.skip(f"Skipping test due to error: {e}")
-
-
-if __name__ == "__main__":
- pytest.main(["-v", __file__])
diff --git a/dimos/perception/pointcloud/utils.py b/dimos/perception/pointcloud/utils.py
deleted file mode 100644
index b2bb561000..0000000000
--- a/dimos/perception/pointcloud/utils.py
+++ /dev/null
@@ -1,1113 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Point cloud utilities for RGBD data processing.
-
-This module provides efficient utilities for creating and manipulating point clouds
-from RGBD images using Open3D.
-"""
-
-import os
-from typing import Any
-
-import cv2
-import numpy as np
-import open3d as o3d # type: ignore[import-untyped]
-from scipy.spatial import cKDTree # type: ignore[import-untyped]
-import yaml # type: ignore[import-untyped]
-
-from dimos.perception.common.utils import project_3d_points_to_2d
-
-
-def load_camera_matrix_from_yaml(
- camera_info: str | list[float] | np.ndarray | dict | None, # type: ignore[type-arg]
-) -> np.ndarray | None: # type: ignore[type-arg]
- """
- Load camera intrinsic matrix from various input formats.
-
- Args:
- camera_info: Can be:
- - Path to YAML file containing camera parameters
- - List of [fx, fy, cx, cy]
- - 3x3 numpy array (returned as-is)
- - Dict with camera parameters
- - None (returns None)
-
- Returns:
- 3x3 camera intrinsic matrix or None if input is None
-
- Raises:
- ValueError: If camera_info format is invalid or file cannot be read
- FileNotFoundError: If YAML file path doesn't exist
- """
- if camera_info is None:
- return None
-
- # Handle case where camera_info is already a matrix
- if isinstance(camera_info, np.ndarray) and camera_info.shape == (3, 3):
- return camera_info.astype(np.float32)
-
- # Handle case where camera_info is [fx, fy, cx, cy] format
- if isinstance(camera_info, list) and len(camera_info) == 4:
- fx, fy, cx, cy = camera_info
- return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
-
- # Handle case where camera_info is a dict
- if isinstance(camera_info, dict):
- return _extract_matrix_from_dict(camera_info)
-
- # Handle case where camera_info is a path to a YAML file
- if isinstance(camera_info, str):
- if not os.path.isfile(camera_info):
- raise FileNotFoundError(f"Camera info file not found: {camera_info}")
-
- try:
- with open(camera_info) as f:
- data = yaml.safe_load(f)
- return _extract_matrix_from_dict(data)
- except Exception as e:
- raise ValueError(f"Failed to read camera info from {camera_info}: {e}")
-
- raise ValueError(
- f"Invalid camera_info format. Expected str, list, dict, or numpy array, got {type(camera_info)}"
- )
-
-
-def _extract_matrix_from_dict(data: dict) -> np.ndarray: # type: ignore[type-arg]
- """Extract camera matrix from dictionary with various formats."""
- # ROS format with 'K' field (most common)
- if "K" in data:
- k_data = data["K"]
- if len(k_data) == 9:
- return np.array(k_data, dtype=np.float32).reshape(3, 3)
-
- # Standard format with 'camera_matrix'
- if "camera_matrix" in data:
- if "data" in data["camera_matrix"]:
- matrix_data = data["camera_matrix"]["data"]
- if len(matrix_data) == 9:
- return np.array(matrix_data, dtype=np.float32).reshape(3, 3)
-
- # Explicit intrinsics format
- if all(k in data for k in ["fx", "fy", "cx", "cy"]):
- fx, fy = float(data["fx"]), float(data["fy"])
- cx, cy = float(data["cx"]), float(data["cy"])
- return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
-
- # Error case - provide helpful debug info
- available_keys = list(data.keys())
- if "K" in data:
- k_info = f"K field length: {len(data['K']) if hasattr(data['K'], '__len__') else 'unknown'}"
- else:
- k_info = "K field not found"
-
- raise ValueError(
- f"Cannot extract camera matrix from data. "
- f"Available keys: {available_keys}. {k_info}. "
- f"Expected formats: 'K' (9 elements), 'camera_matrix.data' (9 elements), "
- f"or individual 'fx', 'fy', 'cx', 'cy' fields."
- )
-
-
-def create_o3d_point_cloud_from_rgbd(
- color_img: np.ndarray, # type: ignore[type-arg]
- depth_img: np.ndarray, # type: ignore[type-arg]
- intrinsic: np.ndarray, # type: ignore[type-arg]
- depth_scale: float = 1.0,
- depth_trunc: float = 3.0,
-) -> o3d.geometry.PointCloud:
- """
- Create an Open3D point cloud from RGB and depth images.
-
- Args:
- color_img: RGB image as numpy array (H, W, 3)
- depth_img: Depth image as numpy array (H, W)
- intrinsic: Camera intrinsic matrix (3x3 numpy array)
- depth_scale: Scale factor to convert depth to meters
- depth_trunc: Maximum depth in meters
-
- Returns:
- Open3D point cloud object
-
- Raises:
- ValueError: If input dimensions are invalid
- """
- # Validate inputs
- if len(color_img.shape) != 3 or color_img.shape[2] != 3:
- raise ValueError(f"color_img must be (H, W, 3), got {color_img.shape}")
- if len(depth_img.shape) != 2:
- raise ValueError(f"depth_img must be (H, W), got {depth_img.shape}")
- if color_img.shape[:2] != depth_img.shape:
- raise ValueError(
- f"Color and depth image dimensions don't match: {color_img.shape[:2]} vs {depth_img.shape}"
- )
- if intrinsic.shape != (3, 3):
- raise ValueError(f"intrinsic must be (3, 3), got {intrinsic.shape}")
-
- # Convert to Open3D format
- color_o3d = o3d.geometry.Image(color_img.astype(np.uint8))
-
- # Filter out inf and nan values from depth image
- depth_filtered = depth_img.copy()
-
- # Create mask for valid depth values (finite, positive, non-zero)
- valid_mask = np.isfinite(depth_filtered) & (depth_filtered > 0)
-
- # Set invalid values to 0 (which Open3D treats as no depth)
- depth_filtered[~valid_mask] = 0.0
-
- depth_o3d = o3d.geometry.Image(depth_filtered.astype(np.float32))
-
- # Create Open3D intrinsic object
- height, width = color_img.shape[:2]
- fx, fy = intrinsic[0, 0], intrinsic[1, 1]
- cx, cy = intrinsic[0, 2], intrinsic[1, 2]
- intrinsic_o3d = o3d.camera.PinholeCameraIntrinsic(
- width,
- height,
- fx,
- fy, # fx, fy
- cx,
- cy, # cx, cy
- )
-
- # Create RGBD image
- rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
- color_o3d,
- depth_o3d,
- depth_scale=depth_scale,
- depth_trunc=depth_trunc,
- convert_rgb_to_intensity=False,
- )
-
- # Create point cloud
- pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic_o3d)
-
- return pcd
-
-
-def create_point_cloud_and_extract_masks(
- color_img: np.ndarray, # type: ignore[type-arg]
- depth_img: np.ndarray, # type: ignore[type-arg]
- masks: list[np.ndarray], # type: ignore[type-arg]
- intrinsic: np.ndarray, # type: ignore[type-arg]
- depth_scale: float = 1.0,
- depth_trunc: float = 3.0,
-) -> tuple[o3d.geometry.PointCloud, list[o3d.geometry.PointCloud]]:
- """
- Efficiently create a point cloud once and extract multiple masked regions.
-
- Args:
- color_img: RGB image (H, W, 3)
- depth_img: Depth image (H, W)
- masks: List of boolean masks, each of shape (H, W)
- intrinsic: Camera intrinsic matrix (3x3 numpy array)
- depth_scale: Scale factor to convert depth to meters
- depth_trunc: Maximum depth in meters
-
- Returns:
- Tuple of (full_point_cloud, list_of_masked_point_clouds)
- """
- if not masks:
- return o3d.geometry.PointCloud(), []
-
- # Create the full point cloud
- full_pcd = create_o3d_point_cloud_from_rgbd(
- color_img, depth_img, intrinsic, depth_scale, depth_trunc
- )
-
- if len(np.asarray(full_pcd.points)) == 0:
- return full_pcd, [o3d.geometry.PointCloud() for _ in masks]
-
- # Create pixel-to-point mapping
- valid_depth_mask = np.isfinite(depth_img) & (depth_img > 0) & (depth_img <= depth_trunc)
-
- valid_depth = valid_depth_mask.flatten()
- if not np.any(valid_depth):
- return full_pcd, [o3d.geometry.PointCloud() for _ in masks]
-
- pixel_to_point = np.full(len(valid_depth), -1, dtype=np.int32)
- pixel_to_point[valid_depth] = np.arange(np.sum(valid_depth))
-
- # Extract point clouds for each mask
- masked_pcds = []
- max_points = len(np.asarray(full_pcd.points))
-
- for mask in masks:
- if mask.shape != depth_img.shape:
- masked_pcds.append(o3d.geometry.PointCloud())
- continue
-
- mask_flat = mask.flatten()
- valid_mask_indices = mask_flat & valid_depth
- point_indices = pixel_to_point[valid_mask_indices]
- valid_point_indices = point_indices[point_indices >= 0]
-
- if len(valid_point_indices) > 0:
- valid_point_indices = np.clip(valid_point_indices, 0, max_points - 1)
- valid_point_indices = np.unique(valid_point_indices)
- masked_pcd = full_pcd.select_by_index(valid_point_indices.tolist())
- else:
- masked_pcd = o3d.geometry.PointCloud()
-
- masked_pcds.append(masked_pcd)
-
- return full_pcd, masked_pcds
-
-
-def filter_point_cloud_statistical(
- pcd: o3d.geometry.PointCloud, nb_neighbors: int = 20, std_ratio: float = 2.0
-) -> tuple[o3d.geometry.PointCloud, np.ndarray]: # type: ignore[type-arg]
- """
- Apply statistical outlier filtering to point cloud.
-
- Args:
- pcd: Input point cloud
- nb_neighbors: Number of neighbors to analyze for each point
- std_ratio: Threshold level based on standard deviation
-
- Returns:
- Tuple of (filtered_point_cloud, outlier_indices)
- """
- if len(np.asarray(pcd.points)) == 0:
- return pcd, np.array([])
-
- return pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) # type: ignore[no-any-return]
-
-
-def filter_point_cloud_radius(
- pcd: o3d.geometry.PointCloud, nb_points: int = 16, radius: float = 0.05
-) -> tuple[o3d.geometry.PointCloud, np.ndarray]: # type: ignore[type-arg]
- """
- Apply radius-based outlier filtering to point cloud.
-
- Args:
- pcd: Input point cloud
- nb_points: Minimum number of points within radius
- radius: Search radius in meters
-
- Returns:
- Tuple of (filtered_point_cloud, outlier_indices)
- """
- if len(np.asarray(pcd.points)) == 0:
- return pcd, np.array([])
-
- return pcd.remove_radius_outlier(nb_points=nb_points, radius=radius) # type: ignore[no-any-return]
-
-
-def overlay_point_clouds_on_image(
- base_image: np.ndarray, # type: ignore[type-arg]
- point_clouds: list[o3d.geometry.PointCloud],
- camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg]
- colors: list[tuple[int, int, int]],
- point_size: int = 2,
- alpha: float = 0.7,
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Overlay multiple colored point clouds onto an image.
-
- Args:
- base_image: Base image to overlay onto (H, W, 3) - assumed to be RGB
- point_clouds: List of Open3D point cloud objects
- camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix
- colors: List of RGB color tuples for each point cloud. If None, generates distinct colors.
- point_size: Size of points to draw (in pixels)
- alpha: Blending factor for overlay (0.0 = fully transparent, 1.0 = fully opaque)
-
- Returns:
- Image with overlaid point clouds (H, W, 3)
- """
- if len(point_clouds) == 0:
- return base_image.copy()
-
- # Create overlay image
- overlay = base_image.copy()
- height, width = base_image.shape[:2]
-
- # Process each point cloud
- for i, pcd in enumerate(point_clouds):
- if pcd is None:
- continue
-
- points_3d = np.asarray(pcd.points)
- if len(points_3d) == 0:
- continue
-
- # Project 3D points to 2D
- points_2d = project_3d_points_to_2d(points_3d, camera_intrinsics)
-
- if len(points_2d) == 0:
- continue
-
- # Filter points within image bounds
- valid_mask = (
- (points_2d[:, 0] >= 0)
- & (points_2d[:, 0] < width)
- & (points_2d[:, 1] >= 0)
- & (points_2d[:, 1] < height)
- )
- valid_points_2d = points_2d[valid_mask]
-
- if len(valid_points_2d) == 0:
- continue
-
- # Get color for this point cloud
- color = colors[i % len(colors)]
-
- # Ensure color is a tuple of integers for OpenCV
- if isinstance(color, list | tuple | np.ndarray):
- color = tuple(int(c) for c in color[:3]) # type: ignore[assignment]
- else:
- color = (255, 255, 255)
-
- # Draw points on overlay
- for point in valid_points_2d:
- u, v = point
- # Draw a small filled circle for each point
- cv2.circle(overlay, (u, v), point_size, color, -1)
-
- # Blend overlay with base image
- result = cv2.addWeighted(base_image, 1 - alpha, overlay, alpha, 0)
-
- return result
-
-
-def create_point_cloud_overlay_visualization(
- base_image: np.ndarray, # type: ignore[type-arg]
- objects: list[dict], # type: ignore[type-arg]
- intrinsics: np.ndarray, # type: ignore[type-arg]
-) -> np.ndarray: # type: ignore[type-arg]
- """
- Create a visualization showing object point clouds and bounding boxes overlaid on a base image.
-
- Args:
- base_image: Base image to overlay onto (H, W, 3)
- objects: List of object dictionaries containing 'point_cloud', 'color', 'position', 'rotation', 'size' keys
- intrinsics: Camera intrinsics as [fx, fy, cx, cy] or 3x3 matrix
-
- Returns:
- Visualization image with overlaid point clouds and bounding boxes (H, W, 3)
- """
- # Extract point clouds and colors from objects
- point_clouds = []
- colors = []
- for obj in objects:
- if "point_cloud" in obj and obj["point_cloud"] is not None:
- point_clouds.append(obj["point_cloud"])
-
- # Convert color to tuple
- color = obj["color"]
- if isinstance(color, np.ndarray):
- color = tuple(int(c) for c in color)
- elif isinstance(color, list | tuple):
- color = tuple(int(c) for c in color[:3])
- colors.append(color)
-
- # Create visualization
- if point_clouds:
- result = overlay_point_clouds_on_image(
- base_image=base_image,
- point_clouds=point_clouds,
- camera_intrinsics=intrinsics,
- colors=colors,
- point_size=3,
- alpha=0.8,
- )
- else:
- result = base_image.copy()
-
- # Draw 3D bounding boxes
- height_img, width_img = result.shape[:2]
- for i, obj in enumerate(objects):
- if all(key in obj and obj[key] is not None for key in ["position", "rotation", "size"]):
- try:
- # Create and project 3D bounding box
- corners_3d = create_3d_bounding_box_corners(
- obj["position"], obj["rotation"], obj["size"]
- )
- corners_2d = project_3d_points_to_2d(corners_3d, intrinsics)
-
- # Check if any corners are visible
- valid_mask = (
- (corners_2d[:, 0] >= 0)
- & (corners_2d[:, 0] < width_img)
- & (corners_2d[:, 1] >= 0)
- & (corners_2d[:, 1] < height_img)
- )
-
- if np.any(valid_mask):
- # Get color
- bbox_color = colors[i] if i < len(colors) else (255, 255, 255)
- draw_3d_bounding_box_on_image(result, corners_2d, bbox_color, thickness=2)
- except:
- continue
-
- return result
-
-
-def create_3d_bounding_box_corners(position, rotation, size: int): # type: ignore[no-untyped-def]
- """
- Create 8 corners of a 3D bounding box from position, rotation, and size.
-
- Args:
- position: Vector or dict with x, y, z coordinates
- rotation: Vector or dict with roll, pitch, yaw angles
- size: Dict with width, height, depth
-
- Returns:
- 8x3 numpy array of corner coordinates
- """
- # Convert position to numpy array
- if hasattr(position, "x"): # Vector object
- center = np.array([position.x, position.y, position.z])
- else: # Dictionary
- center = np.array([position["x"], position["y"], position["z"]])
-
- # Convert rotation (euler angles) to rotation matrix
- if hasattr(rotation, "x"): # Vector object (roll, pitch, yaw)
- roll, pitch, yaw = rotation.x, rotation.y, rotation.z
- else: # Dictionary
- roll, pitch, yaw = rotation["roll"], rotation["pitch"], rotation["yaw"]
-
- # Create rotation matrix from euler angles (ZYX order)
- cos_r, sin_r = np.cos(roll), np.sin(roll)
- cos_p, sin_p = np.cos(pitch), np.sin(pitch)
- cos_y, sin_y = np.cos(yaw), np.sin(yaw)
-
- # Rotation matrix for ZYX euler angles
- R = np.array(
- [
- [
- cos_y * cos_p,
- cos_y * sin_p * sin_r - sin_y * cos_r,
- cos_y * sin_p * cos_r + sin_y * sin_r,
- ],
- [
- sin_y * cos_p,
- sin_y * sin_p * sin_r + cos_y * cos_r,
- sin_y * sin_p * cos_r - cos_y * sin_r,
- ],
- [-sin_p, cos_p * sin_r, cos_p * cos_r],
- ]
- )
-
- # Get dimensions
- width = size.get("width", 0.1) # type: ignore[attr-defined]
- height = size.get("height", 0.1) # type: ignore[attr-defined]
- depth = size.get("depth", 0.1) # type: ignore[attr-defined]
-
- # Create 8 corners of the bounding box (before rotation)
- corners = np.array(
- [
- [-width / 2, -height / 2, -depth / 2], # 0
- [width / 2, -height / 2, -depth / 2], # 1
- [width / 2, height / 2, -depth / 2], # 2
- [-width / 2, height / 2, -depth / 2], # 3
- [-width / 2, -height / 2, depth / 2], # 4
- [width / 2, -height / 2, depth / 2], # 5
- [width / 2, height / 2, depth / 2], # 6
- [-width / 2, height / 2, depth / 2], # 7
- ]
- )
-
- # Apply rotation and translation
- rotated_corners = corners @ R.T + center
-
- return rotated_corners
-
-
-def draw_3d_bounding_box_on_image(image, corners_2d, color, thickness: int = 2) -> None: # type: ignore[no-untyped-def]
- """
- Draw a 3D bounding box on an image using projected 2D corners.
-
- Args:
- image: Image to draw on
- corners_2d: 8x2 array of 2D corner coordinates
- color: RGB color tuple
- thickness: Line thickness
- """
- # Define the 12 edges of a cube (connecting corner indices)
- edges = [
- (0, 1),
- (1, 2),
- (2, 3),
- (3, 0), # Bottom face
- (4, 5),
- (5, 6),
- (6, 7),
- (7, 4), # Top face
- (0, 4),
- (1, 5),
- (2, 6),
- (3, 7), # Vertical edges
- ]
-
- # Draw each edge
- for start_idx, end_idx in edges:
- start_point = tuple(corners_2d[start_idx].astype(int))
- end_point = tuple(corners_2d[end_idx].astype(int))
- cv2.line(image, start_point, end_point, color, thickness)
-
-
-def extract_and_cluster_misc_points(
- full_pcd: o3d.geometry.PointCloud,
- all_objects: list[dict], # type: ignore[type-arg]
- eps: float = 0.03,
- min_points: int = 100,
- enable_filtering: bool = True,
- voxel_size: float = 0.02,
-) -> tuple[list[o3d.geometry.PointCloud], o3d.geometry.VoxelGrid]:
- """
- Extract miscellaneous/background points and cluster them using DBSCAN.
-
- Args:
- full_pcd: Complete scene point cloud
- all_objects: List of objects with point clouds to subtract
- eps: DBSCAN epsilon parameter (max distance between points in cluster)
- min_points: DBSCAN min_samples parameter (min points to form cluster)
- enable_filtering: Whether to apply statistical and radius filtering
- voxel_size: Size of voxels for voxel grid generation
-
- Returns:
- Tuple of (clustered_point_clouds, voxel_grid)
- """
- if full_pcd is None or len(np.asarray(full_pcd.points)) == 0:
- return [], o3d.geometry.VoxelGrid()
-
- if not all_objects:
- # If no objects detected, cluster the full point cloud
- clusters = _cluster_point_cloud_dbscan(full_pcd, eps, min_points)
- voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size)
- return clusters, voxel_grid
-
- try:
- # Start with a copy of the full point cloud
- misc_pcd = o3d.geometry.PointCloud(full_pcd)
-
- # Remove object points by combining all object point clouds
- all_object_points = []
- for obj in all_objects:
- if "point_cloud" in obj and obj["point_cloud"] is not None:
- obj_points = np.asarray(obj["point_cloud"].points)
- if len(obj_points) > 0:
- all_object_points.append(obj_points)
-
- if not all_object_points:
- # No object points to remove, cluster full point cloud
- clusters = _cluster_point_cloud_dbscan(misc_pcd, eps, min_points)
- voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size)
- return clusters, voxel_grid
-
- # Combine all object points
- combined_obj_points = np.vstack(all_object_points)
-
- # For efficiency, downsample both point clouds
- misc_downsampled = misc_pcd.voxel_down_sample(voxel_size=0.005)
-
- # Create object point cloud for efficient operations
- obj_pcd = o3d.geometry.PointCloud()
- obj_pcd.points = o3d.utility.Vector3dVector(combined_obj_points)
- obj_downsampled = obj_pcd.voxel_down_sample(voxel_size=0.005)
-
- misc_points = np.asarray(misc_downsampled.points)
- obj_points_down = np.asarray(obj_downsampled.points)
-
- if len(misc_points) == 0 or len(obj_points_down) == 0:
- clusters = _cluster_point_cloud_dbscan(misc_downsampled, eps, min_points)
- voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size)
- return clusters, voxel_grid
-
- # Build tree for object points
- obj_tree = cKDTree(obj_points_down)
-
- # Find distances from misc points to nearest object points
- distances, _ = obj_tree.query(misc_points, k=1)
-
- # Keep points that are far enough from any object point
- threshold = 0.015 # 1.5cm threshold
- keep_mask = distances > threshold
-
- if not np.any(keep_mask):
- return [], o3d.geometry.VoxelGrid()
-
- # Filter misc points
- misc_indices = np.where(keep_mask)[0]
- final_misc_pcd = misc_downsampled.select_by_index(misc_indices)
-
- if len(np.asarray(final_misc_pcd.points)) == 0:
- return [], o3d.geometry.VoxelGrid()
-
- # Apply additional filtering if enabled
- if enable_filtering:
- # Apply statistical outlier filtering
- filtered_misc_pcd, _ = filter_point_cloud_statistical(
- final_misc_pcd, nb_neighbors=30, std_ratio=2.0
- )
-
- if len(np.asarray(filtered_misc_pcd.points)) == 0:
- return [], o3d.geometry.VoxelGrid()
-
- # Apply radius outlier filtering
- final_filtered_misc_pcd, _ = filter_point_cloud_radius(
- filtered_misc_pcd,
- nb_points=20,
- radius=0.03, # 3cm radius
- )
-
- if len(np.asarray(final_filtered_misc_pcd.points)) == 0:
- return [], o3d.geometry.VoxelGrid()
-
- final_misc_pcd = final_filtered_misc_pcd
-
- # Cluster the misc points using DBSCAN
- clusters = _cluster_point_cloud_dbscan(final_misc_pcd, eps, min_points)
-
- # Create voxel grid from all misc points (before clustering)
- voxel_grid = _create_voxel_grid_from_point_cloud(final_misc_pcd, voxel_size)
-
- return clusters, voxel_grid
-
- except Exception as e:
- print(f"Error in misc point extraction and clustering: {e}")
- # Fallback: return downsampled full point cloud as single cluster
- try:
- downsampled = full_pcd.voxel_down_sample(voxel_size=0.02)
- if len(np.asarray(downsampled.points)) > 0:
- voxel_grid = _create_voxel_grid_from_point_cloud(downsampled, voxel_size)
- return [downsampled], voxel_grid
- else:
- return [], o3d.geometry.VoxelGrid()
- except:
- return [], o3d.geometry.VoxelGrid()
-
-
-def _create_voxel_grid_from_point_cloud(
- pcd: o3d.geometry.PointCloud, voxel_size: float = 0.02
-) -> o3d.geometry.VoxelGrid:
- """
- Create a voxel grid from a point cloud.
-
- Args:
- pcd: Input point cloud
- voxel_size: Size of each voxel
-
- Returns:
- Open3D VoxelGrid object
- """
- if len(np.asarray(pcd.points)) == 0:
- return o3d.geometry.VoxelGrid()
-
- try:
- # Create voxel grid from point cloud
- voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size)
-
- # Color the voxels with a semi-transparent gray
- for voxel in voxel_grid.get_voxels():
- voxel.color = [0.5, 0.5, 0.5] # Gray color
-
- print(
- f"Created voxel grid with {len(voxel_grid.get_voxels())} voxels (voxel_size={voxel_size})"
- )
- return voxel_grid
-
- except Exception as e:
- print(f"Error creating voxel grid: {e}")
- return o3d.geometry.VoxelGrid()
-
-
-def _create_voxel_grid_from_clusters(
- clusters: list[o3d.geometry.PointCloud], voxel_size: float = 0.02
-) -> o3d.geometry.VoxelGrid:
- """
- Create a voxel grid from multiple clustered point clouds.
-
- Args:
- clusters: List of clustered point clouds
- voxel_size: Size of each voxel
-
- Returns:
- Open3D VoxelGrid object
- """
- if not clusters:
- return o3d.geometry.VoxelGrid()
-
- # Combine all clusters into one point cloud
- combined_points = []
- for cluster in clusters:
- points = np.asarray(cluster.points)
- if len(points) > 0:
- combined_points.append(points)
-
- if not combined_points:
- return o3d.geometry.VoxelGrid()
-
- # Create combined point cloud
- all_points = np.vstack(combined_points)
- combined_pcd = o3d.geometry.PointCloud()
- combined_pcd.points = o3d.utility.Vector3dVector(all_points)
-
- return _create_voxel_grid_from_point_cloud(combined_pcd, voxel_size)
-
-
-def _cluster_point_cloud_dbscan(
- pcd: o3d.geometry.PointCloud, eps: float = 0.05, min_points: int = 50
-) -> list[o3d.geometry.PointCloud]:
- """
- Cluster a point cloud using DBSCAN and return list of clustered point clouds.
-
- Args:
- pcd: Point cloud to cluster
- eps: DBSCAN epsilon parameter
- min_points: DBSCAN min_samples parameter
-
- Returns:
- List of point clouds, one for each cluster
- """
- if len(np.asarray(pcd.points)) == 0:
- return []
-
- try:
- # Apply DBSCAN clustering
- labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_points))
-
- # Get unique cluster labels (excluding noise points labeled as -1)
- unique_labels = np.unique(labels)
- cluster_pcds = []
-
- for label in unique_labels:
- if label == -1: # Skip noise points
- continue
-
- # Get indices for this cluster
- cluster_indices = np.where(labels == label)[0]
-
- if len(cluster_indices) > 0:
- # Create point cloud for this cluster
- cluster_pcd = pcd.select_by_index(cluster_indices)
-
- # Assign a random color to this cluster
- cluster_color = np.random.rand(3) # Random RGB color
- cluster_pcd.paint_uniform_color(cluster_color)
-
- cluster_pcds.append(cluster_pcd)
-
- print(
- f"DBSCAN clustering found {len(cluster_pcds)} clusters from {len(np.asarray(pcd.points))} points"
- )
- return cluster_pcds
-
- except Exception as e:
- print(f"Error in DBSCAN clustering: {e}")
- return [pcd] # Return original point cloud as fallback
-
-
-def get_standard_coordinate_transform(): # type: ignore[no-untyped-def]
- """
- Get a standard coordinate transformation matrix for consistent visualization.
-
- This transformation ensures that:
- - X (red) axis points right
- - Y (green) axis points up
- - Z (blue) axis points toward viewer
-
- Returns:
- 4x4 transformation matrix
- """
- # Standard transformation matrix to ensure consistent coordinate frame orientation
- transform = np.array(
- [
- [1, 0, 0, 0], # X points right
- [0, -1, 0, 0], # Y points up (flip from OpenCV to standard)
- [0, 0, -1, 0], # Z points toward viewer (flip depth)
- [0, 0, 0, 1],
- ]
- )
- return transform
-
-
-def visualize_clustered_point_clouds(
- clustered_pcds: list[o3d.geometry.PointCloud],
- window_name: str = "Clustered Point Clouds",
- point_size: float = 2.0,
- show_coordinate_frame: bool = True,
- coordinate_frame_size: float = 0.1,
-) -> None:
- """
- Visualize multiple clustered point clouds with different colors.
-
- Args:
- clustered_pcds: List of point clouds (already colored)
- window_name: Name of the visualization window
- point_size: Size of points in the visualization
- show_coordinate_frame: Whether to show coordinate frame
- coordinate_frame_size: Size of the coordinate frame
- """
- if not clustered_pcds:
- print("Warning: No clustered point clouds to visualize")
- return
-
- # Apply standard coordinate transformation
- transform = get_standard_coordinate_transform() # type: ignore[no-untyped-call]
- geometries = []
- for pcd in clustered_pcds:
- pcd_copy = o3d.geometry.PointCloud(pcd)
- pcd_copy.transform(transform)
- geometries.append(pcd_copy)
-
- # Add coordinate frame
- if show_coordinate_frame:
- coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
- size=coordinate_frame_size
- )
- coordinate_frame.transform(transform)
- geometries.append(coordinate_frame)
-
- total_points = sum(len(np.asarray(pcd.points)) for pcd in clustered_pcds)
- print(f"Visualizing {len(clustered_pcds)} clusters with {total_points} total points")
-
- try:
- vis = o3d.visualization.Visualizer()
- vis.create_window(window_name=window_name, width=1280, height=720)
- for geom in geometries:
- vis.add_geometry(geom)
- render_option = vis.get_render_option()
- render_option.point_size = point_size
- vis.run()
- vis.destroy_window()
- except Exception as e:
- print(f"Failed to create interactive visualization: {e}")
- o3d.visualization.draw_geometries(
- geometries, window_name=window_name, width=1280, height=720
- )
-
-
-def visualize_pcd(
- pcd: o3d.geometry.PointCloud,
- window_name: str = "Point Cloud Visualization",
- point_size: float = 1.0,
- show_coordinate_frame: bool = True,
- coordinate_frame_size: float = 0.1,
-) -> None:
- """
- Visualize an Open3D point cloud using Open3D's visualization window.
-
- Args:
- pcd: Open3D point cloud to visualize
- window_name: Name of the visualization window
- point_size: Size of points in the visualization
- show_coordinate_frame: Whether to show coordinate frame
- coordinate_frame_size: Size of the coordinate frame
- """
- if pcd is None:
- print("Warning: Point cloud is None, nothing to visualize")
- return
-
- if len(np.asarray(pcd.points)) == 0:
- print("Warning: Point cloud is empty, nothing to visualize")
- return
-
- # Apply standard coordinate transformation
- transform = get_standard_coordinate_transform() # type: ignore[no-untyped-call]
- pcd_copy = o3d.geometry.PointCloud(pcd)
- pcd_copy.transform(transform)
- geometries = [pcd_copy]
-
- # Add coordinate frame
- if show_coordinate_frame:
- coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
- size=coordinate_frame_size
- )
- coordinate_frame.transform(transform)
- geometries.append(coordinate_frame)
-
- print(f"Visualizing point cloud with {len(np.asarray(pcd.points))} points")
-
- try:
- vis = o3d.visualization.Visualizer()
- vis.create_window(window_name=window_name, width=1280, height=720)
- for geom in geometries:
- vis.add_geometry(geom)
- render_option = vis.get_render_option()
- render_option.point_size = point_size
- vis.run()
- vis.destroy_window()
- except Exception as e:
- print(f"Failed to create interactive visualization: {e}")
- o3d.visualization.draw_geometries(
- geometries, window_name=window_name, width=1280, height=720
- )
-
-
-def visualize_voxel_grid(
- voxel_grid: o3d.geometry.VoxelGrid,
- window_name: str = "Voxel Grid Visualization",
- show_coordinate_frame: bool = True,
- coordinate_frame_size: float = 0.1,
-) -> None:
- """
- Visualize an Open3D voxel grid using Open3D's visualization window.
-
- Args:
- voxel_grid: Open3D voxel grid to visualize
- window_name: Name of the visualization window
- show_coordinate_frame: Whether to show coordinate frame
- coordinate_frame_size: Size of the coordinate frame
- """
- if voxel_grid is None:
- print("Warning: Voxel grid is None, nothing to visualize")
- return
-
- if len(voxel_grid.get_voxels()) == 0:
- print("Warning: Voxel grid is empty, nothing to visualize")
- return
-
- # VoxelGrid doesn't support transform, so we need to transform the source points instead
- # For now, just visualize as-is with transformed coordinate frame
- geometries = [voxel_grid]
-
- # Add coordinate frame
- if show_coordinate_frame:
- coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
- size=coordinate_frame_size
- )
- coordinate_frame.transform(get_standard_coordinate_transform()) # type: ignore[no-untyped-call]
- geometries.append(coordinate_frame)
-
- print(f"Visualizing voxel grid with {len(voxel_grid.get_voxels())} voxels")
-
- try:
- vis = o3d.visualization.Visualizer()
- vis.create_window(window_name=window_name, width=1280, height=720)
- for geom in geometries:
- vis.add_geometry(geom)
- vis.run()
- vis.destroy_window()
- except Exception as e:
- print(f"Failed to create interactive visualization: {e}")
- o3d.visualization.draw_geometries(
- geometries, window_name=window_name, width=1280, height=720
- )
-
-
-def combine_object_pointclouds(
- point_clouds: list[np.ndarray] | list[o3d.geometry.PointCloud], # type: ignore[type-arg]
- colors: list[np.ndarray] | None = None, # type: ignore[type-arg]
-) -> o3d.geometry.PointCloud:
- """
- Combine multiple point clouds into a single Open3D point cloud.
-
- Args:
- point_clouds: List of point clouds as numpy arrays or Open3D point clouds
- colors: List of colors as numpy arrays
- Returns:
- Combined Open3D point cloud
- """
- all_points = []
- all_colors = []
-
- for i, pcd in enumerate(point_clouds):
- if isinstance(pcd, np.ndarray):
- points = pcd[:, :3]
- all_points.append(points)
- if colors:
- all_colors.append(colors[i])
-
- elif isinstance(pcd, o3d.geometry.PointCloud):
- points = np.asarray(pcd.points)
- all_points.append(points)
- if pcd.has_colors():
- colors = np.asarray(pcd.colors) # type: ignore[assignment]
- all_colors.append(colors) # type: ignore[arg-type]
-
- if not all_points:
- return o3d.geometry.PointCloud()
-
- combined_pcd = o3d.geometry.PointCloud()
- combined_pcd.points = o3d.utility.Vector3dVector(np.vstack(all_points))
-
- if all_colors:
- combined_pcd.colors = o3d.utility.Vector3dVector(np.vstack(all_colors))
-
- return combined_pcd
-
-
-def extract_centroids_from_masks(
- rgb_image: np.ndarray, # type: ignore[type-arg]
- depth_image: np.ndarray, # type: ignore[type-arg]
- masks: list[np.ndarray], # type: ignore[type-arg]
- camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg]
-) -> list[dict[str, Any]]:
- """
- Extract 3D centroids and orientations from segmentation masks.
-
- Args:
- rgb_image: RGB image (H, W, 3)
- depth_image: Depth image (H, W) in meters
- masks: List of boolean masks (H, W)
- camera_intrinsics: Camera parameters as [fx, fy, cx, cy] or 3x3 matrix
-
- Returns:
- List of dictionaries containing:
- - centroid: 3D centroid position [x, y, z] in camera frame
- - orientation: Normalized direction vector from camera to centroid
- - num_points: Number of valid 3D points
- - mask_idx: Index of the mask in the input list
- """
- # Extract camera parameters
- if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4:
- fx, fy, cx, cy = camera_intrinsics
- else:
- fx = camera_intrinsics[0, 0] # type: ignore[call-overload]
- fy = camera_intrinsics[1, 1] # type: ignore[call-overload]
- cx = camera_intrinsics[0, 2] # type: ignore[call-overload]
- cy = camera_intrinsics[1, 2] # type: ignore[call-overload]
-
- results = []
-
- for mask_idx, mask in enumerate(masks):
- if mask is None or mask.sum() == 0:
- continue
-
- # Get pixel coordinates where mask is True
- y_coords, x_coords = np.where(mask)
-
- # Get depth values at mask locations
- depths = depth_image[y_coords, x_coords]
-
- # Convert to 3D points in camera frame
- X = (x_coords - cx) * depths / fx
- Y = (y_coords - cy) * depths / fy
- Z = depths
-
- # Calculate centroid
- centroid_x = np.mean(X)
- centroid_y = np.mean(Y)
- centroid_z = np.mean(Z)
- centroid = np.array([centroid_x, centroid_y, centroid_z])
-
- # Calculate orientation as normalized direction from camera origin to centroid
- # Camera origin is at (0, 0, 0)
- orientation = centroid / np.linalg.norm(centroid)
-
- results.append(
- {
- "centroid": centroid,
- "orientation": orientation,
- "num_points": int(mask.sum()),
- "mask_idx": mask_idx,
- }
- )
-
- return results
diff --git a/dimos/perception/segmentation/__init__.py b/dimos/perception/segmentation/__init__.py
deleted file mode 100644
index a48a76d6a4..0000000000
--- a/dimos/perception/segmentation/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .sam_2d_seg import *
-from .utils import *
diff --git a/dimos/perception/segmentation/config/custom_tracker.yaml b/dimos/perception/segmentation/config/custom_tracker.yaml
deleted file mode 100644
index 7a6748ebf6..0000000000
--- a/dimos/perception/segmentation/config/custom_tracker.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-# Default Ultralytics settings for BoT-SORT tracker when using mode="track"
-# For documentation and examples see https://docs.ultralytics.com/modes/track/
-# For BoT-SORT source code see https://github.com/NirAharon/BoT-SORT
-
-tracker_type: botsort # tracker type, ['botsort', 'bytetrack']
-track_high_thresh: 0.4 # threshold for the first association
-track_low_thresh: 0.2 # threshold for the second association
-new_track_thresh: 0.5 # threshold for init new track if the detection does not match any tracks
-track_buffer: 100 # buffer to calculate the time when to remove tracks
-match_thresh: 0.4 # threshold for matching tracks
-fuse_score: False # Whether to fuse confidence scores with the iou distances before matching
-# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
-
-# BoT-SORT settings
-gmc_method: sparseOptFlow # method of global motion compensation
-# ReID model related thresh (not supported yet)
-proximity_thresh: 0.6
-appearance_thresh: 0.35
-with_reid: False
diff --git a/dimos/perception/segmentation/image_analyzer.py b/dimos/perception/segmentation/image_analyzer.py
deleted file mode 100644
index 06db712ac7..0000000000
--- a/dimos/perception/segmentation/image_analyzer.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import os
-
-import cv2
-from openai import OpenAI
-
-NORMAL_PROMPT = "What are in these images? Give a short word answer with at most two words, \
- if not sure, give a description of its shape or color like 'small tube', 'blue item'. \" \
- if does not look like an object, say 'unknown'. Export objects as a list of strings \
- in this exact format '['object 1', 'object 2', '...']'."
-
-RICH_PROMPT = (
- "What are in these images? Give a detailed description of each item, the first n images will be \
- cropped patches of the original image detected by the object detection model. \
- The last image will be the original image. Use the last image only for context, \
- do not describe objects in the last image. \
- Export the objects as a list of strings in this exact format, '['description of object 1', '...', '...']', \
- don't include anything else. "
-)
-
-
-class ImageAnalyzer:
- def __init__(self) -> None:
- """
- Initializes the ImageAnalyzer with OpenAI API credentials.
- """
- self.client = OpenAI()
-
- def encode_image(self, image): # type: ignore[no-untyped-def]
- """
- Encodes an image to Base64.
-
- Parameters:
- image (numpy array): Image array (BGR format).
-
- Returns:
- str: Base64 encoded string of the image.
- """
- _, buffer = cv2.imencode(".jpg", image)
- return base64.b64encode(buffer).decode("utf-8")
-
- def analyze_images(self, images, detail: str = "auto", prompt_type: str = "normal"): # type: ignore[no-untyped-def]
- """
- Takes a list of cropped images and returns descriptions from OpenAI's Vision model.
-
- Parameters:
- images (list of numpy arrays): Cropped images from the original frame.
- detail (str): "low", "high", or "auto" to set image processing detail.
- prompt_type (str): "normal" or "rich" to set the prompt type.
-
- Returns:
- list of str: Descriptions of objects in each image.
- """
- image_data = [
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{self.encode_image(img)}", # type: ignore[no-untyped-call]
- "detail": detail,
- },
- }
- for img in images
- ]
-
- if prompt_type == "normal":
- prompt = NORMAL_PROMPT
- elif prompt_type == "rich":
- prompt = RICH_PROMPT
- else:
- raise ValueError(f"Invalid prompt type: {prompt_type}")
-
- response = self.client.chat.completions.create(
- model="gpt-4o-mini",
- messages=[
- { # type: ignore[list-item, misc]
- "role": "user",
- "content": [{"type": "text", "text": prompt}, *image_data],
- }
- ],
- max_tokens=300,
- timeout=5,
- )
-
- # Accessing the content of the response using dot notation
- return next(choice.message.content for choice in response.choices)
-
-
-def main() -> None:
- # Define the directory containing cropped images
- cropped_images_dir = "cropped_images"
- if not os.path.exists(cropped_images_dir):
- print(f"Directory '{cropped_images_dir}' does not exist.")
- return
-
- # Load all images from the directory
- images = []
- for filename in os.listdir(cropped_images_dir):
- if filename.endswith(".jpg") or filename.endswith(".png"):
- image_path = os.path.join(cropped_images_dir, filename)
- image = cv2.imread(image_path)
- if image is not None:
- images.append(image)
- else:
- print(f"Warning: Could not read image {image_path}")
-
- if not images:
- print("No valid images found in the directory.")
- return
-
- # Initialize ImageAnalyzer
- analyzer = ImageAnalyzer()
-
- # Analyze images
- results = analyzer.analyze_images(images)
-
- # Split results into a list of items
- object_list = [item.strip()[2:] for item in results.split("\n")]
-
- # Overlay text on images and display them
- for i, (img, obj) in enumerate(zip(images, object_list, strict=False)):
- if obj: # Only process non-empty lines
- # Add text to image
- font = cv2.FONT_HERSHEY_SIMPLEX
- font_scale = 0.5
- thickness = 2
- text = obj.strip()
-
- # Get text size
- (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
-
- # Position text at top of image
- x = 10
- y = text_height + 10
-
- # Add white background for text
- cv2.rectangle(
- img, (x - 5, y - text_height - 5), (x + text_width + 5, y + 5), (255, 255, 255), -1
- )
- # Add text
- cv2.putText(img, text, (x, y), font, font_scale, (0, 0, 0), thickness)
-
- # Save or display the image
- cv2.imwrite(f"annotated_image_{i}.jpg", img)
- print(f"Detected object: {obj}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py
deleted file mode 100644
index 741f71a9ab..0000000000
--- a/dimos/perception/segmentation/sam_2d_seg.py
+++ /dev/null
@@ -1,366 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import deque
-from collections.abc import Sequence
-from concurrent.futures import ThreadPoolExecutor
-import os
-import time
-
-import cv2
-import onnxruntime # type: ignore[import-untyped]
-import torch
-from ultralytics import FastSAM # type: ignore[attr-defined, import-not-found]
-
-from dimos.perception.common.detection2d_tracker import get_tracked_results, target2dTracker
-from dimos.perception.segmentation.image_analyzer import ImageAnalyzer
-from dimos.perception.segmentation.utils import (
- crop_images_from_bboxes,
- extract_masks_bboxes_probs_names,
- filter_segmentation_results,
- plot_results,
-)
-from dimos.utils.data import get_data
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class Sam2DSegmenter:
- def __init__(
- self,
- model_path: str = "models_fastsam",
- model_name: str = "FastSAM-s.onnx",
- min_analysis_interval: float = 5.0,
- use_tracker: bool = True,
- use_analyzer: bool = True,
- use_rich_labeling: bool = False,
- use_filtering: bool = True,
- ) -> None:
- # Use GPU if available, otherwise fall back to CPU
- if torch.cuda.is_available():
- logger.info("Using CUDA for SAM 2d segmenter")
- if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18
- onnxruntime.preload_dlls(cuda=True, cudnn=True)
- self.device = "cuda"
- # MacOS Metal performance shaders
- elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
- logger.info("Using Metal for SAM 2d segmenter")
- self.device = "mps"
- else:
- logger.info("Using CPU for SAM 2d segmenter")
- self.device = "cpu"
-
- # Core components
- self.model = FastSAM(get_data(model_path) / model_name)
- self.use_tracker = use_tracker
- self.use_analyzer = use_analyzer
- self.use_rich_labeling = use_rich_labeling
- self.use_filtering = use_filtering
-
- module_dir = os.path.dirname(__file__)
- self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml")
-
- # Initialize tracker if enabled
- if self.use_tracker:
- self.tracker = target2dTracker(
- history_size=80,
- score_threshold_start=0.7,
- score_threshold_stop=0.05,
- min_frame_count=10,
- max_missed_frames=50,
- min_area_ratio=0.05,
- max_area_ratio=0.4,
- texture_range=(0.0, 0.35),
- border_safe_distance=100,
- weights={"prob": 1.0, "temporal": 3.0, "texture": 2.0, "border": 3.0, "size": 1.0},
- )
-
- # Initialize analyzer components if enabled
- if self.use_analyzer:
- self.image_analyzer = ImageAnalyzer()
- self.min_analysis_interval = min_analysis_interval
- self.last_analysis_time = 0
- self.to_be_analyzed = deque() # type: ignore[var-annotated]
- self.object_names = {} # type: ignore[var-annotated]
- self.analysis_executor = ThreadPoolExecutor(max_workers=1)
- self.current_future = None
- self.current_queue_ids = None
-
- def process_image(self, image): # type: ignore[no-untyped-def]
- """Process an image and return segmentation results."""
- results = self.model.track(
- source=image,
- device=self.device,
- retina_masks=True,
- conf=0.3,
- iou=0.5,
- persist=True,
- verbose=False,
- )
-
- if len(results) > 0:
- # Get initial segmentation results
- masks, bboxes, track_ids, probs, names, areas = extract_masks_bboxes_probs_names(
- results[0]
- )
-
- # Filter results
- if self.use_filtering:
- (
- filtered_masks,
- filtered_bboxes,
- filtered_track_ids,
- filtered_probs,
- filtered_names,
- filtered_texture_values,
- ) = filter_segmentation_results(
- image, masks, bboxes, track_ids, probs, names, areas
- )
- else:
- # Use original results without filtering
- filtered_masks = masks
- filtered_bboxes = bboxes
- filtered_track_ids = track_ids
- filtered_probs = probs
- filtered_names = names
- filtered_texture_values = []
-
- if self.use_tracker:
- # Update tracker with filtered results
- tracked_targets = self.tracker.update(
- image,
- filtered_masks,
- filtered_bboxes,
- filtered_track_ids,
- filtered_probs,
- filtered_names,
- filtered_texture_values,
- )
-
- # Get tracked results
- tracked_masks, tracked_bboxes, tracked_target_ids, tracked_probs, tracked_names = (
- get_tracked_results(tracked_targets) # type: ignore[no-untyped-call]
- )
-
- if self.use_analyzer:
- # Update analysis queue with tracked IDs
- target_id_set = set(tracked_target_ids)
-
- # Remove untracked objects from object_names
- all_target_ids = list(self.tracker.targets.keys())
- self.object_names = {
- track_id: name
- for track_id, name in self.object_names.items()
- if track_id in all_target_ids
- }
-
- # Remove untracked objects from queue and results
- self.to_be_analyzed = deque(
- [track_id for track_id in self.to_be_analyzed if track_id in target_id_set]
- )
-
- # Filter out any IDs being analyzed from the to_be_analyzed queue
- if self.current_queue_ids:
- self.to_be_analyzed = deque(
- [
- tid
- for tid in self.to_be_analyzed
- if tid not in self.current_queue_ids
- ]
- )
-
- # Add new track_ids to analysis queue
- for track_id in tracked_target_ids:
- if (
- track_id not in self.object_names
- and track_id not in self.to_be_analyzed
- ):
- self.to_be_analyzed.append(track_id)
-
- return (
- tracked_masks,
- tracked_bboxes,
- tracked_target_ids,
- tracked_probs,
- tracked_names,
- )
- else:
- # When tracker disabled, just use the filtered results directly
- if self.use_analyzer:
- # Add unanalyzed IDs to the analysis queue
- for track_id in filtered_track_ids:
- if (
- track_id not in self.object_names
- and track_id not in self.to_be_analyzed
- ):
- self.to_be_analyzed.append(track_id)
-
- # Simply return filtered results
- return (
- filtered_masks,
- filtered_bboxes,
- filtered_track_ids,
- filtered_probs,
- filtered_names,
- )
- return [], [], [], [], []
-
- def check_analysis_status(self, tracked_target_ids): # type: ignore[no-untyped-def]
- """Check if analysis is complete and prepare new queue if needed."""
- if not self.use_analyzer:
- return None, None
-
- current_time = time.time()
-
- # Check if current queue analysis is complete
- if self.current_future and self.current_future.done():
- try:
- results = self.current_future.result()
- if results is not None:
- # Map results to track IDs
- object_list = eval(results)
- for track_id, result in zip(self.current_queue_ids, object_list, strict=False):
- self.object_names[track_id] = result
- except Exception as e:
- print(f"Queue analysis failed: {e}")
- self.current_future = None
- self.current_queue_ids = None
- self.last_analysis_time = current_time
-
- # If enough time has passed and we have items to analyze, start new analysis
- if (
- not self.current_future
- and self.to_be_analyzed
- and current_time - self.last_analysis_time >= self.min_analysis_interval
- ):
- queue_indices = []
- queue_ids = []
-
- # Collect all valid track IDs from the queue
- while self.to_be_analyzed:
- track_id = self.to_be_analyzed[0]
- if track_id in tracked_target_ids:
- bbox_idx = tracked_target_ids.index(track_id)
- queue_indices.append(bbox_idx)
- queue_ids.append(track_id)
- self.to_be_analyzed.popleft()
-
- if queue_indices:
- return queue_indices, queue_ids
- return None, None
-
- def run_analysis(self, frame, tracked_bboxes, tracked_target_ids) -> None: # type: ignore[no-untyped-def]
- """Run queue image analysis in background."""
- if not self.use_analyzer:
- return
-
- queue_indices, queue_ids = self.check_analysis_status(tracked_target_ids) # type: ignore[no-untyped-call]
- if queue_indices:
- selected_bboxes = [tracked_bboxes[i] for i in queue_indices]
- cropped_images = crop_images_from_bboxes(frame, selected_bboxes)
- if cropped_images:
- self.current_queue_ids = queue_ids
- print(f"Analyzing objects with track_ids: {queue_ids}")
-
- if self.use_rich_labeling:
- prompt_type = "rich"
- cropped_images.append(frame)
- else:
- prompt_type = "normal"
-
- self.current_future = self.analysis_executor.submit( # type: ignore[assignment]
- self.image_analyzer.analyze_images, cropped_images, prompt_type=prompt_type
- )
-
- def get_object_names(self, track_ids, tracked_names: Sequence[str]): # type: ignore[no-untyped-def]
- """Get object names for the given track IDs, falling back to tracked names."""
- if not self.use_analyzer:
- return tracked_names
-
- return [
- self.object_names.get(track_id, tracked_name)
- for track_id, tracked_name in zip(track_ids, tracked_names, strict=False)
- ]
-
- def visualize_results( # type: ignore[no-untyped-def]
- self, image, masks, bboxes, track_ids, probs: Sequence[float], names: Sequence[str]
- ):
- """Generate an overlay visualization with segmentation results and object names."""
- return plot_results(image, masks, bboxes, track_ids, probs, names)
-
- def cleanup(self) -> None:
- """Cleanup resources."""
- if self.use_analyzer:
- self.analysis_executor.shutdown()
-
-
-def main() -> None:
- # Example usage with different configurations
- cap = cv2.VideoCapture(0)
-
- # Example 1: Full functionality with rich labeling
- segmenter = Sam2DSegmenter(
- min_analysis_interval=4.0,
- use_tracker=True,
- use_analyzer=True,
- use_rich_labeling=True, # Enable rich labeling
- )
-
- # Example 2: Full functionality with normal labeling
- # segmenter = Sam2DSegmenter(min_analysis_interval=4.0, use_tracker=True, use_analyzer=True)
-
- # Example 3: Tracker only (analyzer disabled)
- # segmenter = Sam2DSegmenter(use_analyzer=False)
-
- # Example 4: Basic segmentation only (both tracker and analyzer disabled)
- # segmenter = Sam2DSegmenter(use_tracker=False, use_analyzer=False)
-
- # Example 5: Analyzer without tracker (new capability)
- # segmenter = Sam2DSegmenter(use_tracker=False, use_analyzer=True)
-
- try:
- while cap.isOpened():
- ret, frame = cap.read()
- if not ret:
- break
-
- time.time()
-
- # Process image and get results
- masks, bboxes, target_ids, probs, names = segmenter.process_image(frame) # type: ignore[no-untyped-call]
-
- # Run analysis if enabled
- if segmenter.use_analyzer:
- segmenter.run_analysis(frame, bboxes, target_ids)
- names = segmenter.get_object_names(target_ids, names)
-
- # processing_time = time.time() - start_time
- # print(f"Processing time: {processing_time:.2f}s")
-
- overlay = segmenter.visualize_results(frame, masks, bboxes, target_ids, probs, names)
-
- cv2.imshow("Segmentation", overlay)
- key = cv2.waitKey(1)
- if key & 0xFF == ord("q"):
- break
-
- finally:
- segmenter.cleanup()
- cap.release()
- cv2.destroyAllWindows()
-
-
-if __name__ == "__main__":
- main()
diff --git a/dimos/perception/segmentation/test_sam_2d_seg.py b/dimos/perception/segmentation/test_sam_2d_seg.py
deleted file mode 100644
index a9222ed2f2..0000000000
--- a/dimos/perception/segmentation/test_sam_2d_seg.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import time
-
-import numpy as np
-import pytest
-from reactivex import operators as ops
-
-from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter
-from dimos.perception.segmentation.utils import extract_masks_bboxes_probs_names
-from dimos.stream.video_provider import VideoProvider
-
-
-@pytest.mark.heavy
-class TestSam2DSegmenter:
- def test_sam_segmenter_initialization(self) -> None:
- """Test FastSAM segmenter initializes correctly with default model path."""
- try:
- # Try to initialize with the default model path and existing device setting
- segmenter = Sam2DSegmenter(use_analyzer=False)
- assert segmenter is not None
- assert segmenter.model is not None
- except Exception as e:
- # If the model file doesn't exist, the test should still pass with a warning
- pytest.skip(f"Skipping test due to model initialization error: {e}")
-
- def test_sam_segmenter_process_image(self) -> None:
- """Test FastSAM segmenter can process video frames and return segmentation masks."""
- # Import get data inside method to avoid pytest fixture confusion
- from dimos.utils.data import get_data
-
- # Get test video path directly
- video_path = get_data("assets") / "trimmed_video_office.mov"
- try:
- # Initialize segmenter without analyzer for faster testing
- segmenter = Sam2DSegmenter(use_analyzer=False)
-
- # Note: conf and iou are parameters for process_image, not constructor
- # We'll monkey patch the process_image method to use lower thresholds
-
- def patched_process_image(image):
- results = segmenter.model.track(
- source=image,
- device=segmenter.device,
- retina_masks=True,
- conf=0.1, # Lower confidence threshold for testing
- iou=0.5, # Lower IoU threshold
- persist=True,
- verbose=False,
- tracker=segmenter.tracker_config
- if hasattr(segmenter, "tracker_config")
- else None,
- )
-
- if len(results) > 0:
- masks, bboxes, track_ids, probs, names, _areas = (
- extract_masks_bboxes_probs_names(results[0])
- )
- return masks, bboxes, track_ids, probs, names
- return [], [], [], [], []
-
- # Replace the method
- segmenter.process_image = patched_process_image
-
- # Create video provider and directly get a video stream observable
- assert os.path.exists(video_path), f"Test video not found: {video_path}"
- video_provider = VideoProvider(dev_name="test_video", video_source=video_path)
-
- video_stream = video_provider.capture_video_as_observable(realtime=False, fps=1)
-
- # Use ReactiveX operators to process the stream
- def process_frame(frame):
- try:
- # Process frame with FastSAM
- masks, bboxes, track_ids, probs, names = segmenter.process_image(frame)
- print(
- f"SAM results - masks: {len(masks)}, bboxes: {len(bboxes)}, track_ids: {len(track_ids)}, names: {len(names)}"
- )
-
- return {
- "frame": frame,
- "masks": masks,
- "bboxes": bboxes,
- "track_ids": track_ids,
- "probs": probs,
- "names": names,
- }
- except Exception as e:
- print(f"Error in process_frame: {e}")
- return {}
-
- # Create the segmentation stream using pipe and map operator
- segmentation_stream = video_stream.pipe(ops.map(process_frame))
-
- # Collect results from the stream
- results = []
- frames_processed = 0
- target_frames = 5
-
- def on_next(result) -> None:
- nonlocal frames_processed, results
- if not result:
- return
-
- results.append(result)
- frames_processed += 1
-
- # Stop processing after target frames
- if frames_processed >= target_frames:
- subscription.dispose()
-
- def on_error(error) -> None:
- pytest.fail(f"Error in segmentation stream: {error}")
-
- def on_completed() -> None:
- pass
-
- # Subscribe and wait for results
- subscription = segmentation_stream.subscribe(
- on_next=on_next, on_error=on_error, on_completed=on_completed
- )
-
- # Wait for frames to be processed
- timeout = 30.0 # seconds
- start_time = time.time()
- while frames_processed < target_frames and time.time() - start_time < timeout:
- time.sleep(0.5)
-
- # Clean up subscription
- subscription.dispose()
- video_provider.dispose_all()
-
- # Check if we have results
- if len(results) == 0:
- pytest.skip(
- "No segmentation results found, but test connection established correctly"
- )
- return
-
- print(f"Processed {len(results)} frames with segmentation results")
-
- # Analyze the first result
- result = results[0]
-
- # Check that we have a frame
- assert "frame" in result, "Result doesn't contain a frame"
- assert isinstance(result["frame"], np.ndarray), "Frame is not a numpy array"
-
- # Check that segmentation results are valid
- assert isinstance(result["masks"], list)
- assert isinstance(result["bboxes"], list)
- assert isinstance(result["track_ids"], list)
- assert isinstance(result["probs"], list)
- assert isinstance(result["names"], list)
-
- # All result lists should be the same length
- assert (
- len(result["masks"])
- == len(result["bboxes"])
- == len(result["track_ids"])
- == len(result["probs"])
- == len(result["names"])
- )
-
- # If we have masks, check that they have valid shape
- if result.get("masks") and len(result["masks"]) > 0:
- assert result["masks"][0].shape == (
- result["frame"].shape[0],
- result["frame"].shape[1],
- ), "Mask shape should match image dimensions"
- print(f"Found {len(result['masks'])} masks in first frame")
- else:
- print("No masks found in first frame, but test connection established correctly")
-
- # Test visualization function
- if result["masks"]:
- vis_frame = segmenter.visualize_results(
- result["frame"],
- result["masks"],
- result["bboxes"],
- result["track_ids"],
- result["probs"],
- result["names"],
- )
- assert isinstance(vis_frame, np.ndarray), "Visualization output should be an image"
- assert vis_frame.shape == result["frame"].shape, (
- "Visualization should have same dimensions as input frame"
- )
-
- # We've already tested visualization above, so no need for a duplicate test
-
- except Exception as e:
- pytest.skip(f"Skipping test due to error: {e}")
-
-
-if __name__ == "__main__":
- pytest.main(["-v", __file__])
diff --git a/dimos/perception/segmentation/utils.py b/dimos/perception/segmentation/utils.py
deleted file mode 100644
index a23a256ca2..0000000000
--- a/dimos/perception/segmentation/utils.py
+++ /dev/null
@@ -1,343 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections.abc import Sequence
-
-import cv2
-import numpy as np
-import torch
-
-
-class SimpleTracker:
- def __init__(
- self, history_size: int = 100, min_count: int = 10, count_window: int = 20
- ) -> None:
- """
- Simple temporal tracker that counts appearances in a fixed window.
- :param history_size: Number of past frames to remember
- :param min_count: Minimum number of appearances required
- :param count_window: Number of latest frames to consider for counting
- """
- self.history = [] # type: ignore[var-annotated]
- self.history_size = history_size
- self.min_count = min_count
- self.count_window = count_window
- self.total_counts = {} # type: ignore[var-annotated]
-
- def update(self, track_ids): # type: ignore[no-untyped-def]
- # Add new frame's track IDs to history
- self.history.append(track_ids)
- if len(self.history) > self.history_size:
- self.history.pop(0)
-
- # Consider only the latest `count_window` frames for counting
- recent_history = self.history[-self.count_window :]
- all_tracks = np.concatenate(recent_history) if recent_history else np.array([])
-
- # Compute occurrences efficiently using numpy
- unique_ids, counts = np.unique(all_tracks, return_counts=True)
- id_counts = dict(zip(unique_ids, counts, strict=False))
-
- # Update total counts but ensure it only contains IDs within the history size
- total_tracked_ids = np.concatenate(self.history) if self.history else np.array([])
- unique_total_ids, total_counts = np.unique(total_tracked_ids, return_counts=True)
- self.total_counts = dict(zip(unique_total_ids, total_counts, strict=False))
-
- # Return IDs that appear often enough
- return [track_id for track_id, count in id_counts.items() if count >= self.min_count]
-
- def get_total_counts(self): # type: ignore[no-untyped-def]
- """Returns the total count of each tracking ID seen over time, limited to history size."""
- return self.total_counts
-
-
-def extract_masks_bboxes_probs_names(result, max_size: float = 0.7): # type: ignore[no-untyped-def]
- """
- Extracts masks, bounding boxes, probabilities, and class names from one Ultralytics result object.
-
- Parameters:
- result: Ultralytics result object
- max_size: float, maximum allowed size of object relative to image (0-1)
-
- Returns:
- tuple: (masks, bboxes, track_ids, probs, names, areas)
- """
- masks = [] # type: ignore[var-annotated]
- bboxes = [] # type: ignore[var-annotated]
- track_ids = [] # type: ignore[var-annotated]
- probs = [] # type: ignore[var-annotated]
- names = [] # type: ignore[var-annotated]
- areas = [] # type: ignore[var-annotated]
-
- if result.masks is None:
- return masks, bboxes, track_ids, probs, names, areas
-
- total_area = result.masks.orig_shape[0] * result.masks.orig_shape[1]
-
- for box, mask_data in zip(result.boxes, result.masks.data, strict=False):
- mask_numpy = mask_data
-
- # Extract bounding box
- x1, y1, x2, y2 = box.xyxy[0].tolist()
-
- # Extract track_id if available
- track_id = -1 # default if no tracking
- if hasattr(box, "id") and box.id is not None:
- track_id = int(box.id[0].item())
-
- # Extract probability and class index
- conf = float(box.conf[0])
- cls_idx = int(box.cls[0])
- area = (x2 - x1) * (y2 - y1)
-
- if area / total_area > max_size:
- continue
-
- masks.append(mask_numpy)
- bboxes.append([x1, y1, x2, y2])
- track_ids.append(track_id)
- probs.append(conf)
- names.append(result.names[cls_idx])
- areas.append(area)
-
- return masks, bboxes, track_ids, probs, names, areas
-
-
-def compute_texture_map(frame, blur_size: int = 3): # type: ignore[no-untyped-def]
- """
- Compute texture map using gradient statistics.
- Returns high values for textured regions and low values for smooth regions.
-
- Parameters:
- frame: BGR image
- blur_size: Size of Gaussian blur kernel for pre-processing
-
- Returns:
- numpy array: Texture map with values normalized to [0,1]
- """
- # Convert to grayscale
- if len(frame.shape) == 3:
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
- else:
- gray = frame
-
- # Pre-process with slight blur to reduce noise
- if blur_size > 0:
- gray = cv2.GaussianBlur(gray, (blur_size, blur_size), 0)
-
- # Compute gradients in x and y directions
- grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
- grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
-
- # Compute gradient magnitude and direction
- magnitude = np.sqrt(grad_x**2 + grad_y**2)
-
- # Compute local standard deviation of gradient magnitude
- texture_map = cv2.GaussianBlur(magnitude, (15, 15), 0)
-
- # Normalize to [0,1]
- texture_map = (texture_map - texture_map.min()) / (texture_map.max() - texture_map.min() + 1e-8)
-
- return texture_map
-
-
-def filter_segmentation_results( # type: ignore[no-untyped-def]
- frame,
- masks,
- bboxes,
- track_ids,
- probs: Sequence[float],
- names: Sequence[str],
- areas,
- texture_threshold: float = 0.07,
- size_filter: int = 800,
-):
- """
- Filters segmentation results using both overlap and saliency detection.
- Uses mask_sum tensor for efficient overlap detection.
-
- Parameters:
- masks: list of torch.Tensor containing mask data
- bboxes: list of bounding boxes [x1, y1, x2, y2]
- track_ids: list of tracking IDs
- probs: list of confidence scores
- names: list of class names
- areas: list of object areas
- frame: BGR image for computing saliency
- texture_threshold: Average texture value required for mask to be kept
- size_filter: Minimum size of the object to be kept
-
- Returns:
- tuple: (filtered_masks, filtered_bboxes, filtered_track_ids, filtered_probs, filtered_names, filtered_texture_values, texture_map)
- """
- if len(masks) <= 1:
- return masks, bboxes, track_ids, probs, names, []
-
- # Compute texture map once and convert to tensor
- texture_map = compute_texture_map(frame)
-
- # Sort by area (smallest to largest)
- sorted_indices = torch.tensor(areas).argsort(descending=False)
-
- device = masks[0].device # Get the device of the first mask
-
- # Create mask_sum tensor where each pixel stores the index of the mask that claims it
- mask_sum = torch.zeros_like(masks[0], dtype=torch.int32)
-
- texture_map = torch.from_numpy(texture_map).to(
- device
- ) # Convert texture_map to tensor and move to device
-
- filtered_texture_values = [] # List to store texture values of filtered masks
-
- for i, idx in enumerate(sorted_indices):
- mask = masks[idx]
- # Compute average texture value within mask
- texture_value = torch.mean(texture_map[mask > 0]) if torch.any(mask > 0) else 0
-
- # Only claim pixels if mask passes texture threshold
- if texture_value >= texture_threshold:
- mask_sum[mask > 0] = i
- filtered_texture_values.append(
- texture_value.item() # type: ignore[union-attr]
- ) # Store the texture value as a Python float
-
- # Get indices that appear in mask_sum (these are the masks we want to keep)
- keep_indices, counts = torch.unique(mask_sum[mask_sum > 0], return_counts=True)
- size_indices = counts > size_filter
- keep_indices = keep_indices[size_indices]
-
- sorted_indices = sorted_indices.cpu()
- keep_indices = keep_indices.cpu()
-
- # Map back to original indices and filter
- final_indices = sorted_indices[keep_indices].tolist()
-
- filtered_masks = [masks[i] for i in final_indices]
- filtered_bboxes = [bboxes[i] for i in final_indices]
- filtered_track_ids = [track_ids[i] for i in final_indices]
- filtered_probs = [probs[i] for i in final_indices]
- filtered_names = [names[i] for i in final_indices]
-
- return (
- filtered_masks,
- filtered_bboxes,
- filtered_track_ids,
- filtered_probs,
- filtered_names,
- filtered_texture_values,
- )
-
-
-def plot_results( # type: ignore[no-untyped-def]
- image,
- masks,
- bboxes,
- track_ids,
- probs: Sequence[float],
- names: Sequence[str],
- alpha: float = 0.5,
-):
- """
- Draws bounding boxes, masks, and labels on the given image with enhanced visualization.
- Includes object names in the overlay and improved text visibility.
- """
- h, w = image.shape[:2]
- overlay = image.copy()
-
- for mask, bbox, track_id, prob, name in zip(
- masks, bboxes, track_ids, probs, names, strict=False
- ):
- # Convert mask tensor to numpy if needed
- if isinstance(mask, torch.Tensor):
- mask = mask.cpu().numpy()
-
- # Ensure mask is in proper format for OpenCV resize
- if mask.dtype == bool:
- mask = mask.astype(np.uint8)
- elif mask.dtype != np.uint8 and mask.dtype != np.float32:
- mask = mask.astype(np.float32)
-
- mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
-
- # Generate consistent color based on track_id
- if track_id != -1:
- np.random.seed(track_id)
- color = np.random.randint(0, 255, (3,), dtype=np.uint8)
- np.random.seed(None)
- else:
- color = np.random.randint(0, 255, (3,), dtype=np.uint8)
-
- # Apply mask color
- overlay[mask_resized > 0.5] = color
-
- # Draw bounding box
- x1, y1, x2, y2 = map(int, bbox)
- cv2.rectangle(overlay, (x1, y1), (x2, y2), color.tolist(), 2)
-
- # Prepare label text
- label = f"ID:{track_id} {prob:.2f}"
- if name: # Add object name if available
- label += f" {name}"
-
- # Calculate text size for background rectangle
- (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
-
- # Draw background rectangle for text
- cv2.rectangle(overlay, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1)
-
- # Draw text with white color for better visibility
- cv2.putText(
- overlay,
- label,
- (x1 + 2, y1 - 5),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.5,
- (255, 255, 255), # White text
- 1,
- )
-
- # Blend overlay with original image
- result = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
- return result
-
-
-def crop_images_from_bboxes(image, bboxes, buffer: int = 0): # type: ignore[no-untyped-def]
- """
- Crops regions from an image based on bounding boxes with an optional buffer.
-
- Parameters:
- image (numpy array): Input image.
- bboxes (list of lists): List of bounding boxes [x1, y1, x2, y2].
- buffer (int): Number of pixels to expand each bounding box.
-
- Returns:
- list of numpy arrays: Cropped image regions.
- """
- height, width, _ = image.shape
- cropped_images = []
-
- for bbox in bboxes:
- x1, y1, x2, y2 = bbox
-
- # Apply buffer
- x1 = max(0, x1 - buffer)
- y1 = max(0, y1 - buffer)
- x2 = min(width, x2 + buffer)
- y2 = min(height, y2 + buffer)
-
- cropped_image = image[int(y1) : int(y2), int(x1) : int(x2)]
- cropped_images.append(cropped_image)
-
- return cropped_images
diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py
index 013d242ba8..e33820f22c 100644
--- a/dimos/perception/spatial_perception.py
+++ b/dimos/perception/spatial_perception.py
@@ -32,7 +32,8 @@
from dimos.agents_deprecated.memory.spatial_vector_db import SpatialVectorDB
from dimos.agents_deprecated.memory.visual_memory import VisualMemory
from dimos.constants import DIMOS_PROJECT_ROOT
-from dimos.core import DimosCluster, In, Module, rpc
+from dimos.core import DimosCluster, In, rpc
+from dimos.core.skill_module import SkillModule
from dimos.msgs.sensor_msgs import Image
from dimos.types.robot_location import RobotLocation
from dimos.utils.logging_config import setup_logger
@@ -50,7 +51,7 @@
logger = setup_logger()
-class SpatialMemory(Module):
+class SpatialMemory(SkillModule):
"""
A Dask module for building and querying Robot spatial memory.
@@ -216,8 +217,12 @@ def stop(self) -> None:
def _process_frame(self) -> None:
"""Process the latest frame with pose data if available."""
- tf = self.tf.get("map", "base_link")
- if self._latest_video_frame is None or tf is None:
+ tf = self.tf.get("world", "base_link")
+
+ if tf is None:
+ return
+
+ if self._latest_video_frame is None:
return
# Create Pose object with position and orientation
@@ -501,7 +506,7 @@ def add_named_location(
Returns:
True if successfully added, False otherwise
"""
- tf = self.tf.get("map", "base_link")
+ tf = self.tf.get("world", "base_link")
if not tf:
logger.error("No position available for robot location")
return False
diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py
index 48a2b2750f..47518b889b 100644
--- a/dimos/perception/test_spatial_memory_module.py
+++ b/dimos/perception/test_spatial_memory_module.py
@@ -110,6 +110,7 @@ def stop(self) -> None:
@pytest.mark.gpu
+@pytest.mark.neverending
class TestSpatialMemoryModule:
@pytest.fixture(scope="function")
def temp_dir(self):
diff --git a/dimos/protocol/mcp/test_mcp_module.py b/dimos/protocol/mcp/test_mcp_module.py
index 1deb5b9057..2a247e6ff0 100644
--- a/dimos/protocol/mcp/test_mcp_module.py
+++ b/dimos/protocol/mcp/test_mcp_module.py
@@ -21,7 +21,6 @@
import socket
import subprocess
import sys
-import threading
import pytest
@@ -134,6 +133,7 @@ async def wait_for_updates(self) -> bool:
assert "Error:" in response["result"]["content"][0]["text"]
+@pytest.mark.integration
def test_mcp_end_to_end_lcm_bridge() -> None:
try:
import lcm # type: ignore[import-untyped]
diff --git a/dimos/protocol/pubsub/benchmark/test_benchmark.py b/dimos/protocol/pubsub/benchmark/test_benchmark.py
new file mode 100644
index 0000000000..865c4ee324
--- /dev/null
+++ b/dimos/protocol/pubsub/benchmark/test_benchmark.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python3
+
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Generator
+import threading
+import time
+from typing import Any
+
+import pytest
+
+from dimos.protocol.pubsub.benchmark.testdata import testcases
+from dimos.protocol.pubsub.benchmark.type import (
+ BenchmarkResult,
+ BenchmarkResults,
+ Case,
+ MsgGen,
+ PubSubContext,
+)
+
+# Message sizes for throughput benchmarking (64B up to 10MB)
+MSG_SIZES = [
+ 64,
+ 256,
+ 1024,
+ 4096,
+ 16384,
+ 65536,
+ 262144,
+ 524288,
+ 1048576,
+ 1048576 * 2,
+ 1048576 * 5,
+ 1048576 * 10,
+]
+
+# Benchmark duration in seconds
+BENCH_DURATION = 1.0
+
+# Max messages to send per test (prevents overwhelming slower transports)
+MAX_MESSAGES = 5000
+
+# Max time to wait for in-flight messages after publishing stops
+RECEIVE_TIMEOUT = 1.0
+
+
+def size_id(size: int) -> str:
+ """Convert byte size to human-readable string for test IDs."""
+ if size >= 1048576:
+ return f"{size // 1048576}MB"
+ if size >= 1024:
+ return f"{size // 1024}KB"
+ return f"{size}B"
+
+
+def pubsub_id(testcase: Case[Any, Any]) -> str:
+ """Extract pubsub implementation name from context manager function name."""
+ name: str = testcase.pubsub_context.__name__
+ # Convert e.g. "lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory"
+ prefix = name.replace("_pubsub_channel", "").replace("_", " ")
+ return prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "")
+
+
+@pytest.fixture(scope="module")
+def benchmark_results() -> Generator[BenchmarkResults, None, None]:
+ """Module-scoped fixture to collect benchmark results."""
+ results = BenchmarkResults()
+ yield results
+ results.print_summary()
+ results.print_heatmap()
+ results.print_bandwidth_heatmap()
+ results.print_latency_heatmap()
+
+
+@pytest.mark.tool
+@pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES])
+@pytest.mark.parametrize("pubsub_context, msggen", testcases, ids=[pubsub_id(t) for t in testcases])
+def test_throughput(
+ pubsub_context: PubSubContext[Any, Any],
+ msggen: MsgGen[Any, Any],
+ msg_size: int,
+ benchmark_results: BenchmarkResults,
+) -> None:
+ """Measure throughput for publishing and receiving messages over a fixed duration."""
+ with pubsub_context() as pubsub:
+ topic, msg = msggen(msg_size)
+ received_count = 0
+ target_count = [0] # Use list to allow modification after publish loop
+ lock = threading.Lock()
+ all_received = threading.Event()
+
+ def callback(message: Any, _topic: Any) -> None:
+ nonlocal received_count
+ with lock:
+ received_count += 1
+ if target_count[0] > 0 and received_count >= target_count[0]:
+ all_received.set()
+
+ # Subscribe
+ pubsub.subscribe(topic, callback)
+
+ # Warmup: give DDS/ROS time to establish connection
+ time.sleep(0.1)
+
+ # Set target so callback can signal when all received
+ target_count[0] = MAX_MESSAGES
+
+ # Publish messages until time limit, max messages, or all received
+ msgs_sent = 0
+ start = time.perf_counter()
+ end_time = start + BENCH_DURATION
+
+ while time.perf_counter() < end_time and msgs_sent < MAX_MESSAGES:
+ pubsub.publish(topic, msg)
+ msgs_sent += 1
+ # Check if all already received (fast transports)
+ if all_received.is_set():
+ break
+
+ publish_end = time.perf_counter()
+ target_count[0] = msgs_sent # Update to actual sent count
+
+ # Check if already done, otherwise wait up to RECEIVE_TIMEOUT
+ with lock:
+ if received_count >= msgs_sent:
+ all_received.set()
+
+ if not all_received.is_set():
+ all_received.wait(timeout=RECEIVE_TIMEOUT)
+ latency_end = time.perf_counter()
+
+ with lock:
+ final_received = received_count
+
+ # Latency: how long we waited after publishing for messages to arrive
+ # 0 = all arrived during publishing, 1000ms = hit timeout (loss occurred)
+ latency = latency_end - publish_end
+
+ # Record result (duration is publish time only for throughput calculation)
+ # Extract transport name from context manager function name
+ ctx_name = pubsub_context.__name__
+ prefix = ctx_name.replace("_pubsub_channel", "").replace("_", " ")
+ transport_name = prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "")
+ result = BenchmarkResult(
+ transport=transport_name,
+ duration=publish_end - start,
+ msgs_sent=msgs_sent,
+ msgs_received=final_received,
+ msg_size_bytes=msg_size,
+ receive_time=latency,
+ )
+ benchmark_results.add(result)
+
+ # Warn if significant message loss (but don't fail - benchmark records the data)
+ loss_pct = (1 - final_received / msgs_sent) * 100 if msgs_sent > 0 else 0
+ if loss_pct > 10:
+ import warnings
+
+ warnings.warn(
+ f"{transport_name} {msg_size}B: {loss_pct:.1f}% message loss "
+ f"({final_received}/{msgs_sent})",
+ stacklevel=2,
+ )
diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py
new file mode 100644
index 0000000000..beb140227f
--- /dev/null
+++ b/dimos/protocol/pubsub/benchmark/testdata.py
@@ -0,0 +1,269 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any
+
+import numpy as np
+
+from dimos.msgs.sensor_msgs.Image import Image, ImageFormat
+from dimos.protocol.pubsub.benchmark.type import Case
+from dimos.protocol.pubsub.lcmpubsub import LCM, LCMPubSubBase, Topic as LCMTopic
+from dimos.protocol.pubsub.memory import Memory
+from dimos.protocol.pubsub.shmpubsub import BytesSharedMemory, LCMSharedMemory, PickleSharedMemory
+
+
+def make_data_bytes(size: int) -> bytes:
+    """Generate deterministic bytes of the given size (repeating 0-255 pattern)."""
+ return bytes(i % 256 for i in range(size))
+
+
+def make_data_image(size: int) -> Image:
+ """Generate an RGB Image with approximately `size` bytes of data."""
+ raw_data = np.frombuffer(make_data_bytes(size), dtype=np.uint8).reshape(-1)
+ # Pad to make it divisible by 3 for RGB
+ padded_size = ((len(raw_data) + 2) // 3) * 3
+ padded_data = np.pad(raw_data, (0, padded_size - len(raw_data)))
+ pixels = len(padded_data) // 3
+ # Find reasonable dimensions
+ height = max(1, int(pixels**0.5))
+ width = pixels // height
+ data = padded_data[: height * width * 3].reshape(height, width, 3)
+ return Image(data=data, format=ImageFormat.RGB)
+
+
+testcases: list[Case[Any, Any]] = []
+
+
+@contextmanager
+def lcm_pubsub_channel() -> Generator[LCM, None, None]:
+ lcm_pubsub = LCM(autoconf=True)
+ lcm_pubsub.start()
+ yield lcm_pubsub
+ lcm_pubsub.stop()
+
+
+def lcm_msggen(size: int) -> tuple[LCMTopic, Image]:
+ topic = LCMTopic(topic="benchmark/lcm", lcm_type=Image)
+ return (topic, make_data_image(size))
+
+
+testcases.append(
+ Case(
+ pubsub_context=lcm_pubsub_channel,
+ msg_gen=lcm_msggen,
+ )
+)
+
+
+@contextmanager
+def udp_bytes_pubsub_channel() -> Generator[LCMPubSubBase, None, None]:
+ """LCM with raw bytes - no encoding overhead."""
+ lcm_pubsub = LCMPubSubBase(autoconf=True)
+ lcm_pubsub.start()
+ yield lcm_pubsub
+ lcm_pubsub.stop()
+
+
+def udp_bytes_msggen(size: int) -> tuple[LCMTopic, bytes]:
+ """Generate raw bytes for LCM transport benchmark."""
+ topic = LCMTopic(topic="benchmark/lcm_raw")
+ return (topic, make_data_bytes(size))
+
+
+testcases.append(
+ Case(
+ pubsub_context=udp_bytes_pubsub_channel,
+ msg_gen=udp_bytes_msggen,
+ )
+)
+
+
+@contextmanager
+def memory_pubsub_channel() -> Generator[Memory, None, None]:
+ """Context manager for Memory PubSub implementation."""
+ yield Memory()
+
+
+def memory_msggen(size: int) -> tuple[str, Any]:
+ return ("benchmark/memory", make_data_image(size))
+
+
+# testcases.append(
+# Case(
+# pubsub_context=memory_pubsub_channel,
+# msg_gen=memory_msggen,
+# )
+# )
+
+
+@contextmanager
+def shm_pickle_pubsub_channel() -> Generator[PickleSharedMemory, None, None]:
+ # 12MB capacity to handle benchmark sizes up to 10MB
+ shm_pubsub = PickleSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024)
+ shm_pubsub.start()
+ yield shm_pubsub
+ shm_pubsub.stop()
+
+
+def shm_msggen(size: int) -> tuple[str, Any]:
+ """Generate message for SharedMemory pubsub benchmark."""
+ return ("benchmark/shm", make_data_image(size))
+
+
+testcases.append(
+ Case(
+ pubsub_context=shm_pickle_pubsub_channel,
+ msg_gen=shm_msggen,
+ )
+)
+
+
+@contextmanager
+def shm_bytes_pubsub_channel() -> Generator[BytesSharedMemory, None, None]:
+ """SharedMemory with raw bytes - no pickle overhead."""
+ shm_pubsub = BytesSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024)
+ shm_pubsub.start()
+ yield shm_pubsub
+ shm_pubsub.stop()
+
+
+def shm_bytes_msggen(size: int) -> tuple[str, bytes]:
+ """Generate raw bytes for SharedMemory transport benchmark."""
+ return ("benchmark/shm_bytes", make_data_bytes(size))
+
+
+testcases.append(
+ Case(
+ pubsub_context=shm_bytes_pubsub_channel,
+ msg_gen=shm_bytes_msggen,
+ )
+)
+
+
+@contextmanager
+def shm_lcm_pubsub_channel() -> Generator[LCMSharedMemory, None, None]:
+ """SharedMemory with LCM binary encoding - no pickle overhead."""
+ shm_pubsub = LCMSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024)
+ shm_pubsub.start()
+ yield shm_pubsub
+ shm_pubsub.stop()
+
+
+testcases.append(
+ Case(
+ pubsub_context=shm_lcm_pubsub_channel,
+ msg_gen=lcm_msggen, # Reuse the LCM message generator
+ )
+)
+
+
+try:
+ from dimos.protocol.pubsub.redispubsub import Redis
+
+ @contextmanager
+ def redis_pubsub_channel() -> Generator[Redis, None, None]:
+ redis_pubsub = Redis()
+ redis_pubsub.start()
+ yield redis_pubsub
+ redis_pubsub.stop()
+
+ def redis_msggen(size: int) -> tuple[str, Any]:
+ # Redis uses JSON serialization, so use a simple dict with base64-encoded data
+ import base64
+
+ data = base64.b64encode(make_data_bytes(size)).decode("ascii")
+ return ("benchmark/redis", {"data": data, "size": size})
+
+ testcases.append(
+ Case(
+ pubsub_context=redis_pubsub_channel,
+ msg_gen=redis_msggen,
+ )
+ )
+
+except (ConnectionError, ImportError):
+    # redis is not installed (connection errors surface later, when the channel starts)
+ print("Redis not available")
+
+
+from dimos.protocol.pubsub.rospubsub import ROS_AVAILABLE, RawROS, RawROSTopic
+
+if ROS_AVAILABLE:
+ from rclpy.qos import QoSDurabilityPolicy, QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy
+ from sensor_msgs.msg import Image as ROSImage
+
+ @contextmanager
+ def ros_best_effort_pubsub_channel() -> Generator[RawROS, None, None]:
+ qos = QoSProfile(
+ reliability=QoSReliabilityPolicy.BEST_EFFORT,
+ history=QoSHistoryPolicy.KEEP_LAST,
+ durability=QoSDurabilityPolicy.VOLATILE,
+ depth=5000,
+ )
+ ros_pubsub = RawROS(node_name="benchmark_ros_best_effort", qos=qos)
+ ros_pubsub.start()
+ yield ros_pubsub
+ ros_pubsub.stop()
+
+ @contextmanager
+ def ros_reliable_pubsub_channel() -> Generator[RawROS, None, None]:
+ qos = QoSProfile(
+ reliability=QoSReliabilityPolicy.RELIABLE,
+ history=QoSHistoryPolicy.KEEP_LAST,
+ durability=QoSDurabilityPolicy.VOLATILE,
+ depth=5000,
+ )
+ ros_pubsub = RawROS(node_name="benchmark_ros_reliable", qos=qos)
+ ros_pubsub.start()
+ yield ros_pubsub
+ ros_pubsub.stop()
+
+ def ros_msggen(size: int) -> tuple[RawROSTopic, ROSImage]:
+ import numpy as np
+
+ # Create image data
+ data = np.frombuffer(make_data_bytes(size), dtype=np.uint8).reshape(-1)
+ padded_size = ((len(data) + 2) // 3) * 3
+ data = np.pad(data, (0, padded_size - len(data)))
+ pixels = len(data) // 3
+ height = max(1, int(pixels**0.5))
+ width = pixels // height
+ data = data[: height * width * 3]
+
+ # Create ROS Image message
+ msg = ROSImage()
+ msg.height = height
+ msg.width = width
+ msg.encoding = "rgb8"
+ msg.step = width * 3
+ msg.data = data.tobytes()
+
+ topic = RawROSTopic(topic="/benchmark/ros", ros_type=ROSImage)
+ return (topic, msg)
+
+ testcases.append(
+ Case(
+ pubsub_context=ros_best_effort_pubsub_channel,
+ msg_gen=ros_msggen,
+ )
+ )
+
+ testcases.append(
+ Case(
+ pubsub_context=ros_reliable_pubsub_channel,
+ msg_gen=ros_msggen,
+ )
+ )
diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py
new file mode 100644
index 0000000000..79101df9c5
--- /dev/null
+++ b/dimos/protocol/pubsub/benchmark/type.py
@@ -0,0 +1,266 @@
+#!/usr/bin/env python3
+
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable, Iterator, Sequence
+from contextlib import AbstractContextManager
+from dataclasses import dataclass, field
+from typing import Any, Generic
+
+from dimos.protocol.pubsub.spec import MsgT, PubSub, TopicT
+
+MsgGen = Callable[[int], tuple[TopicT, MsgT]]
+
+PubSubContext = Callable[[], AbstractContextManager[PubSub[TopicT, MsgT]]]
+
+
+@dataclass
+class Case(Generic[TopicT, MsgT]):
+ pubsub_context: PubSubContext[TopicT, MsgT]
+ msg_gen: MsgGen[TopicT, MsgT]
+
+ def __iter__(self) -> Iterator[PubSubContext[TopicT, MsgT] | MsgGen[TopicT, MsgT]]:
+ return iter((self.pubsub_context, self.msg_gen))
+
+ def __len__(self) -> int:
+ return 2
+
+
+TestData = Sequence[Case[Any, Any]]
+
+
+def _format_size(size_bytes: int) -> str:
+ """Format byte size to human-readable string."""
+ if size_bytes >= 1048576:
+ return f"{size_bytes / 1048576:.1f} MB"
+ if size_bytes >= 1024:
+ return f"{size_bytes / 1024:.1f} KB"
+ return f"{size_bytes} B"
+
+
+def _format_throughput(bytes_per_sec: float) -> str:
+ """Format throughput to human-readable string."""
+ if bytes_per_sec >= 1e9:
+ return f"{bytes_per_sec / 1e9:.2f} GB/s"
+ if bytes_per_sec >= 1e6:
+ return f"{bytes_per_sec / 1e6:.2f} MB/s"
+ if bytes_per_sec >= 1e3:
+ return f"{bytes_per_sec / 1e3:.2f} KB/s"
+ return f"{bytes_per_sec:.2f} B/s"
+
+
+@dataclass
+class BenchmarkResult:
+ transport: str
+ duration: float # Time spent publishing
+ msgs_sent: int
+ msgs_received: int
+ msg_size_bytes: int
+ receive_time: float = 0.0 # Time after publishing until all messages received
+
+ @property
+ def total_time(self) -> float:
+ """Total time including latency."""
+ return self.duration + self.receive_time
+
+ @property
+ def throughput_msgs(self) -> float:
+ """Messages per second (including latency)."""
+ return self.msgs_received / self.total_time if self.total_time > 0 else 0
+
+ @property
+ def throughput_bytes(self) -> float:
+ """Bytes per second (including latency)."""
+ return (
+ (self.msgs_received * self.msg_size_bytes) / self.total_time
+ if self.total_time > 0
+ else 0
+ )
+
+ @property
+ def loss_pct(self) -> float:
+ """Message loss percentage."""
+ return (1 - self.msgs_received / self.msgs_sent) * 100 if self.msgs_sent > 0 else 0
+
+
+@dataclass
+class BenchmarkResults:
+ results: list[BenchmarkResult] = field(default_factory=list)
+
+ def add(self, result: BenchmarkResult) -> None:
+ self.results.append(result)
+
+ def print_summary(self) -> None:
+ if not self.results:
+ return
+
+ from rich.console import Console
+ from rich.table import Table
+
+ console = Console()
+
+ table = Table(title="Benchmark Results")
+ table.add_column("Transport", style="cyan")
+ table.add_column("Msg Size", justify="right")
+ table.add_column("Sent", justify="right")
+ table.add_column("Recv", justify="right")
+ table.add_column("Msgs/s", justify="right", style="green")
+ table.add_column("Throughput", justify="right", style="green")
+ table.add_column("Latency", justify="right")
+ table.add_column("Loss", justify="right")
+
+ for r in sorted(self.results, key=lambda x: (x.transport, x.msg_size_bytes)):
+ loss_style = "red" if r.loss_pct > 0 else "dim"
+ recv_style = "yellow" if r.receive_time > 0.1 else "dim"
+ table.add_row(
+ r.transport,
+ _format_size(r.msg_size_bytes),
+ f"{r.msgs_sent:,}",
+ f"{r.msgs_received:,}",
+ f"{r.throughput_msgs:,.0f}",
+ _format_throughput(r.throughput_bytes),
+ f"[{recv_style}]{r.receive_time * 1000:.0f}ms[/{recv_style}]",
+ f"[{loss_style}]{r.loss_pct:.1f}%[/{loss_style}]",
+ )
+
+ console.print()
+ console.print(table)
+
+ def _print_heatmap(
+ self,
+ title: str,
+ value_fn: Callable[[BenchmarkResult], float],
+ format_fn: Callable[[float], str],
+ high_is_good: bool = True,
+ ) -> None:
+ """Generic heatmap printer."""
+ if not self.results:
+ return
+
+ def size_id(size: int) -> str:
+ if size >= 1048576:
+ return f"{size // 1048576}MB"
+ if size >= 1024:
+ return f"{size // 1024}KB"
+ return f"{size}B"
+
+ transports = sorted(set(r.transport for r in self.results))
+ sizes = sorted(set(r.msg_size_bytes for r in self.results))
+
+ # Build matrix
+ matrix: list[list[float]] = []
+ for transport in transports:
+ row = []
+ for size in sizes:
+ result = next(
+ (
+ r
+ for r in self.results
+ if r.transport == transport and r.msg_size_bytes == size
+ ),
+ None,
+ )
+ row.append(value_fn(result) if result else 0)
+ matrix.append(row)
+
+ all_vals = [v for row in matrix for v in row if v > 0]
+ if not all_vals:
+ return
+ min_val, max_val = min(all_vals), max(all_vals)
+
+ # ANSI 256 gradient: red -> orange -> yellow -> green
+ gradient = [
+ 52,
+ 88,
+ 124,
+ 160,
+ 196,
+ 202,
+ 208,
+ 214,
+ 220,
+ 226,
+ 190,
+ 154,
+ 148,
+ 118,
+ 82,
+ 46,
+ 40,
+ 34,
+ ]
+ if not high_is_good:
+ gradient = gradient[::-1]
+
+ def val_to_color(v: float) -> int:
+ if v <= 0 or max_val == min_val:
+ return 236
+ t = (v - min_val) / (max_val - min_val)
+ return gradient[int(t * (len(gradient) - 1))]
+
+ reset = "\033[0m"
+ size_labels = [size_id(s) for s in sizes]
+ col_w = max(8, max(len(s) for s in size_labels) + 1)
+ transport_w = max(len(t) for t in transports) + 1
+
+ print()
+ print(f"{title:^{transport_w + col_w * len(sizes)}}")
+ print()
+ print(" " * transport_w + "".join(f"{s:^{col_w}}" for s in size_labels))
+
+ # Dark colors that need white text (dark reds)
+ dark_colors = {52, 88, 124, 160, 236}
+
+ for i, transport in enumerate(transports):
+ row_str = f"{transport:<{transport_w}}"
+ for val in matrix[i]:
+ color = val_to_color(val)
+ fg = 255 if color in dark_colors else 16 # white on dark, black on bright
+ cell = format_fn(val) if val > 0 else "-"
+ row_str += f"\033[48;5;{color}m\033[38;5;{fg}m{cell:^{col_w}}{reset}"
+ print(row_str)
+ print()
+
+ def print_heatmap(self) -> None:
+ """Print msgs/sec heatmap."""
+
+ def fmt(v: float) -> str:
+ return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.0f}"
+
+ self._print_heatmap("Msgs/sec", lambda r: r.throughput_msgs, fmt)
+
+ def print_bandwidth_heatmap(self) -> None:
+ """Print bandwidth heatmap."""
+
+ def fmt(v: float) -> str:
+ if v >= 1e9:
+ return f"{v / 1e9:.1f}G"
+ if v >= 1e6:
+ return f"{v / 1e6:.0f}M"
+ if v >= 1e3:
+ return f"{v / 1e3:.0f}K"
+ return f"{v:.0f}"
+
+ self._print_heatmap("Bandwidth", lambda r: r.throughput_bytes, fmt)
+
+ def print_latency_heatmap(self) -> None:
+ """Print latency heatmap (time waiting for messages after publishing)."""
+
+ def fmt(v: float) -> str:
+ if v >= 1:
+ return f"{v:.1f}s"
+ return f"{v * 1000:.0f}ms"
+
+ self._print_heatmap("Latency", lambda r: r.receive_time, fmt, high_is_good=False)
diff --git a/dimos/protocol/pubsub/jpeg_shm.py b/dimos/protocol/pubsub/jpeg_shm.py
index c61848c57a..f2c9e35814 100644
--- a/dimos/protocol/pubsub/jpeg_shm.py
+++ b/dimos/protocol/pubsub/jpeg_shm.py
@@ -12,8 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from dimos.protocol.pubsub.lcmpubsub import JpegSharedMemoryEncoderMixin
+from typing import Any
+
+from turbojpeg import TurboJPEG # type: ignore[import-untyped]
+
+from dimos.msgs.sensor_msgs.Image import Image
+from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ImageFormat
from dimos.protocol.pubsub.shmpubsub import SharedMemoryPubSubBase
+from dimos.protocol.pubsub.spec import PubSubEncoderMixin
+
+
+class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image, bytes]):
+ def __init__(self, quality: int = 75, **kwargs) -> None: # type: ignore[no-untyped-def]
+ super().__init__(**kwargs)
+ self.jpeg = TurboJPEG()
+ self.quality = quality
+
+ def encode(self, msg: Any, _topic: str) -> bytes:
+ if not isinstance(msg, Image):
+ raise ValueError("Can only encode images.")
+
+ bgr_image = msg.to_bgr().to_opencv()
+ return self.jpeg.encode(bgr_image, quality=self.quality) # type: ignore[no-any-return]
+
+ def decode(self, msg: bytes, _topic: str) -> Image:
+ bgr_array = self.jpeg.decode(msg)
+ return Image(data=bgr_array, format=ImageFormat.BGR)
class JpegSharedMemory(JpegSharedMemoryEncoderMixin, SharedMemoryPubSubBase): # type: ignore[misc]
diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py
index 9207e7dfc0..471b8d6076 100644
--- a/dimos/protocol/pubsub/lcmpubsub.py
+++ b/dimos/protocol/pubsub/lcmpubsub.py
@@ -15,41 +15,29 @@
from __future__ import annotations
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
+from typing import TYPE_CHECKING, Any
-from turbojpeg import TurboJPEG # type: ignore[import-untyped]
-
-from dimos.msgs.sensor_msgs import Image
-from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ImageFormat
from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin
-from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf
+from dimos.protocol.service.lcmservice import (
+ LCMConfig,
+ LCMService,
+ autoconf,
+)
from dimos.utils.logging_config import setup_logger
if TYPE_CHECKING:
from collections.abc import Callable
import threading
-logger = setup_logger()
-
+ from dimos.msgs import DimosMsg
-@runtime_checkable
-class LCMMsg(Protocol):
- msg_name: str
-
- @classmethod
- def lcm_decode(cls, data: bytes) -> LCMMsg:
- """Decode bytes into an LCM message instance."""
- ...
-
- def lcm_encode(self) -> bytes:
- """Encode this message instance into bytes."""
- ...
+logger = setup_logger()
@dataclass
class Topic:
topic: str = ""
- lcm_type: type[LCMMsg] | None = None
+ lcm_type: type[DimosMsg] | None = None
def __str__(self) -> str:
if self.lcm_type is None:
@@ -99,11 +87,11 @@ def unsubscribe() -> None:
return unsubscribe
-class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]):
- def encode(self, msg: LCMMsg, _: Topic) -> bytes:
+class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any, bytes]):
+ def encode(self, msg: DimosMsg, _: Topic) -> bytes:
return msg.lcm_encode()
- def decode(self, msg: bytes, topic: Topic) -> LCMMsg:
+ def decode(self, msg: bytes, topic: Topic) -> DimosMsg:
if topic.lcm_type is None:
raise ValueError(
f"Cannot decode message for topic '{topic.topic}': no lcm_type specified"
@@ -111,11 +99,11 @@ def decode(self, msg: bytes, topic: Topic) -> LCMMsg:
return topic.lcm_type.lcm_decode(msg)
-class JpegEncoderMixin(PubSubEncoderMixin[Topic, Any]):
- def encode(self, msg: LCMMsg, _: Topic) -> bytes:
+class JpegEncoderMixin(PubSubEncoderMixin[Topic, Any, bytes]):
+ def encode(self, msg: DimosMsg, _: Topic) -> bytes:
return msg.lcm_jpeg_encode() # type: ignore[attr-defined, no-any-return]
- def decode(self, msg: bytes, topic: Topic) -> LCMMsg:
+ def decode(self, msg: bytes, topic: Topic) -> DimosMsg:
if topic.lcm_type is None:
raise ValueError(
f"Cannot decode message for topic '{topic.topic}': no lcm_type specified"
@@ -123,24 +111,6 @@ def decode(self, msg: bytes, topic: Topic) -> LCMMsg:
return topic.lcm_type.lcm_jpeg_decode(msg) # type: ignore[attr-defined, no-any-return]
-class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image]):
- def __init__(self, quality: int = 75, **kwargs) -> None: # type: ignore[no-untyped-def]
- super().__init__(**kwargs)
- self.jpeg = TurboJPEG()
- self.quality = quality
-
- def encode(self, msg: Any, _topic: str) -> bytes:
- if not isinstance(msg, Image):
- raise ValueError("Can only encode images.")
-
- bgr_image = msg.to_bgr().to_opencv()
- return self.jpeg.encode(bgr_image, quality=self.quality) # type: ignore[no-any-return]
-
- def decode(self, msg: bytes, _topic: str) -> Image:
- bgr_array = self.jpeg.decode(msg)
- return Image(data=bgr_array, format=ImageFormat.BGR)
-
-
class LCM(
LCMEncoderMixin,
LCMPubSubBase,
@@ -163,8 +133,6 @@ class JpegLCM(
"LCM",
"JpegLCM",
"LCMEncoderMixin",
- "LCMMsg",
- "LCMMsg",
"LCMPubSubBase",
"PickleLCM",
"autoconf",
diff --git a/dimos/protocol/pubsub/rospubsub.py b/dimos/protocol/pubsub/rospubsub.py
new file mode 100644
index 0000000000..fdb64aa257
--- /dev/null
+++ b/dimos/protocol/pubsub/rospubsub.py
@@ -0,0 +1,311 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+import threading
+from typing import Any, Protocol, runtime_checkable
+
+try:
+ import rclpy
+ from rclpy.executors import SingleThreadedExecutor
+ from rclpy.node import Node
+ from rclpy.qos import (
+ QoSDurabilityPolicy,
+ QoSHistoryPolicy,
+ QoSProfile,
+ QoSReliabilityPolicy,
+ )
+
+ ROS_AVAILABLE = True
+except ImportError:
+ ROS_AVAILABLE = False
+ rclpy = None # type: ignore[assignment]
+ SingleThreadedExecutor = None # type: ignore[assignment, misc]
+ Node = None # type: ignore[assignment, misc]
+
+import uuid
+
+from dimos.msgs import DimosMsg
+from dimos.protocol.pubsub.rospubsub_conversion import (
+ derive_ros_type,
+ dimos_to_ros,
+ ros_to_dimos,
+)
+from dimos.protocol.pubsub.spec import PubSub
+
+
+@runtime_checkable
+class ROSMessage(Protocol):
+ """Protocol for ROS message types."""
+
+ def get_fields_and_field_types(self) -> dict[str, str]: ...
+
+
+@dataclass
+class RawROSTopic:
+ """Topic descriptor for raw ROS pubsub (uses ROS types directly)."""
+
+ topic: str
+ ros_type: type
+ qos: "QoSProfile | None" = None
+
+
+@dataclass
+class ROSTopic:
+ """Topic descriptor for DimosROS pubsub (uses dimos message types)."""
+
+ topic: str
+ msg_type: type[DimosMsg]
+ qos: "QoSProfile | None" = None
+
+
+class RawROS(PubSub[RawROSTopic, Any]):
+ """ROS 2 PubSub implementation following the PubSub spec.
+
+ This allows direct comparison of ROS messaging performance against
+ native LCM and other pubsub implementations.
+ """
+
+ def __init__(self, node_name: str | None = None, qos: "QoSProfile | None" = None) -> None:
+ """Initialize the ROS pubsub.
+
+ Args:
+ node_name: Name for the ROS node (auto-generated if None)
+            qos: Optional QoS profile (defaults to RELIABLE, KEEP_LAST, depth 5000)
+ """
+ if not ROS_AVAILABLE:
+ raise ImportError("rclpy is not installed. ROS pubsub requires ROS 2.")
+
+ # Use unique node name to avoid conflicts in tests
+ self._node_name = node_name or f"dimos_ros_{uuid.uuid4().hex[:8]}"
+ self._node: Node | None = None
+ self._executor: SingleThreadedExecutor | None = None
+ self._spin_thread: threading.Thread | None = None
+ self._stop_event = threading.Event()
+
+ # Track publishers and subscriptions
+ self._publishers: dict[str, Any] = {}
+ self._subscriptions: dict[str, list[tuple[Any, Callable[[Any, RawROSTopic], None]]]] = {}
+ self._lock = threading.Lock()
+
+ # QoS profile - use provided or default to best-effort for throughput
+ if qos is not None:
+ self._qos = qos
+ else:
+ self._qos = QoSProfile(
+ # Haven't noticed any difference between BEST_EFFORT and RELIABLE for local comms in our tests
+ # ./bin/dev python -m pytest -svm tool -k ros dimos/protocol/pubsub/benchmark/test_benchmark.py
+ #
+ # but RELIABLE seems to have marginally higher throughput
+ reliability=QoSReliabilityPolicy.RELIABLE,
+ history=QoSHistoryPolicy.KEEP_LAST,
+ durability=QoSDurabilityPolicy.VOLATILE,
+ depth=5000,
+ )
+
+ def start(self) -> None:
+ """Start the ROS node and executor."""
+ if self._spin_thread is not None:
+ return
+
+ if not rclpy.ok():
+ rclpy.init()
+
+ self._stop_event.clear()
+ self._node = Node(self._node_name)
+ self._executor = SingleThreadedExecutor()
+ self._executor.add_node(self._node)
+
+ self._spin_thread = threading.Thread(target=self._spin, name="ros_pubsub_spin")
+ self._spin_thread.start()
+
+ def stop(self) -> None:
+ """Stop the ROS node and clean up."""
+ if self._spin_thread is None:
+ return
+
+ # Signal spin thread to stop and shutdown executor
+ self._stop_event.set()
+ if self._executor:
+ self._executor.shutdown() # This stops spin_once from blocking
+
+ # Wait for spin thread to exit
+ self._spin_thread.join(timeout=1.0)
+
+ # Grab references while holding lock, then destroy without lock
+ with self._lock:
+ subs_to_destroy = [
+ sub for topic_subs in self._subscriptions.values() for sub, _ in topic_subs
+ ]
+ pubs_to_destroy = list(self._publishers.values())
+ self._subscriptions.clear()
+ self._publishers.clear()
+
+ if self._node:
+ for subscription in subs_to_destroy:
+ self._node.destroy_subscription(subscription)
+ for publisher in pubs_to_destroy:
+ self._node.destroy_publisher(publisher)
+
+ if self._node:
+ self._node.destroy_node()
+ self._node = None
+
+ self._executor = None
+ self._spin_thread = None
+
+ def _spin(self) -> None:
+ """Background thread for spinning the ROS executor."""
+ while not self._stop_event.is_set():
+ executor = self._executor
+ if executor is None:
+ break
+ try:
+ executor.spin_once(timeout_sec=0.01)
+ except Exception:
+ break
+
+ def _get_or_create_publisher(self, topic: RawROSTopic) -> Any:
+ """Get existing publisher or create a new one."""
+ with self._lock:
+ if topic.topic not in self._publishers:
+ node = self._node
+ if node is None:
+ raise RuntimeError("Pubsub must be started before publishing")
+ qos = topic.qos if topic.qos is not None else self._qos
+ self._publishers[topic.topic] = node.create_publisher(
+ topic.ros_type, topic.topic, qos
+ )
+ return self._publishers[topic.topic]
+
+ def publish(self, topic: RawROSTopic, message: Any) -> None:
+ """Publish a message to a ROS topic.
+
+ Args:
+ topic: RawROSTopic descriptor with topic name and message type
+ message: ROS message to publish
+ """
+ if self._node is None:
+ return
+
+ publisher = self._get_or_create_publisher(topic)
+ publisher.publish(message)
+
+ def subscribe(
+ self, topic: RawROSTopic, callback: Callable[[Any, RawROSTopic], None]
+ ) -> Callable[[], None]:
+ """Subscribe to a ROS topic with a callback.
+
+ Args:
+ topic: RawROSTopic descriptor with topic name and message type
+ callback: Function called with (message, topic) when message received
+
+ Returns:
+ Unsubscribe function
+ """
+ if self._node is None:
+ raise RuntimeError("ROS pubsub not started")
+
+ with self._lock:
+
+ def ros_callback(msg: Any) -> None:
+ callback(msg, topic)
+
+ qos = topic.qos if topic.qos is not None else self._qos
+ subscription = self._node.create_subscription(
+ topic.ros_type, topic.topic, ros_callback, qos
+ )
+
+ if topic.topic not in self._subscriptions:
+ self._subscriptions[topic.topic] = []
+ self._subscriptions[topic.topic].append((subscription, callback))
+
+ def unsubscribe() -> None:
+ with self._lock:
+ if topic.topic in self._subscriptions:
+ self._subscriptions[topic.topic] = [
+ (sub, cb)
+ for sub, cb in self._subscriptions[topic.topic]
+ if cb is not callback
+ ]
+ if self._node:
+ self._node.destroy_subscription(subscription)
+
+ return unsubscribe
+
+
+class DimosROS(PubSub[ROSTopic, DimosMsg]):
+ """ROS PubSub with automatic dimos.msgs ↔ ROS message conversion.
+
+ Uses ROSTopic (with dimos msg_type) instead of RawROSTopic (with ros_type).
+ Automatically converts between dimos and ROS message formats.
+ Uses composition with RawROS internally.
+ """
+
+ def __init__(self, node_name: str | None = None, qos: "QoSProfile | None" = None) -> None:
+ """Initialize the DimosROS pubsub.
+
+ Args:
+ node_name: Name for the ROS node (auto-generated if None)
+            qos: Optional QoS profile (defaults to RELIABLE, KEEP_LAST, depth 5000)
+ """
+ self._raw = RawROS(node_name, qos)
+
+ def start(self) -> None:
+ """Start the ROS node and executor."""
+ self._raw.start()
+
+ def stop(self) -> None:
+ """Stop the ROS node and clean up."""
+ self._raw.stop()
+
+ def _to_raw_topic(self, topic: ROSTopic) -> RawROSTopic:
+ """Convert a ROSTopic to a RawROSTopic by deriving the ROS type."""
+ ros_type = derive_ros_type(topic.msg_type)
+ return RawROSTopic(topic=topic.topic, ros_type=ros_type, qos=topic.qos)
+
+ def publish(self, topic: ROSTopic, message: DimosMsg) -> None:
+ """Publish a dimos message to a ROS topic.
+
+ Args:
+ topic: ROSTopic with dimos msg_type
+ message: Dimos message to publish
+ """
+ raw_topic = self._to_raw_topic(topic)
+ ros_message = dimos_to_ros(message, raw_topic.ros_type)
+ self._raw.publish(raw_topic, ros_message)
+
+ def subscribe(
+ self, topic: ROSTopic, callback: Callable[[DimosMsg, ROSTopic], None]
+ ) -> Callable[[], None]:
+ """Subscribe to a ROS topic with automatic dimos message conversion.
+
+ Args:
+ topic: ROSTopic with dimos msg_type
+ callback: Function called with (dimos_message, topic)
+
+ Returns:
+ Unsubscribe function
+ """
+ raw_topic = self._to_raw_topic(topic)
+
+ def wrapped_callback(ros_msg: Any, _raw_topic: RawROSTopic) -> None:
+ dimos_msg = ros_to_dimos(ros_msg, topic.msg_type)
+ callback(dimos_msg, topic)
+
+ return self._raw.subscribe(raw_topic, wrapped_callback)
+
+
+ROS = DimosROS
diff --git a/dimos/protocol/pubsub/rospubsub_conversion.py b/dimos/protocol/pubsub/rospubsub_conversion.py
new file mode 100644
index 0000000000..18181d76b3
--- /dev/null
+++ b/dimos/protocol/pubsub/rospubsub_conversion.py
@@ -0,0 +1,365 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Conversion functions between dimos messages and ROS messages.
+
+This module provides conversion functions between dimos message types and ROS messages.
+It handles three categories of types:
+
+1. Complex types (different internal representation) - use LCM roundtrip
+2. Simple types (field structures match) - use direct field copy
+3. Types with no dimos.msgs equivalent - callers fall back to the dimos_lcm type
+"""
+
+from __future__ import annotations
+
+import importlib
+import re
+from typing import TYPE_CHECKING, Any, cast
+
+if TYPE_CHECKING:
+ from dimos.msgs import DimosMsg
+ from dimos.protocol.pubsub.rospubsub import ROSMessage
+
+
+# Complex types that need LCM roundtrip (explicit list)
+# These types have different internal representations in dimos vs ROS/LCM
+COMPLEX_TYPES: set[str] = {
+ "sensor_msgs.PointCloud2",
+ "sensor_msgs.Image",
+ "sensor_msgs.CameraInfo",
+ "geometry_msgs.PoseStamped",
+}
+
+# Cache for dynamic imports of dimos types
+_dimos_type_cache: dict[str, type[DimosMsg] | None] = {}
+
+# Cache for LCM type derivation
+_lcm_type_cache: dict[str, type[Any]] = {}
+
+# Field name mappings between ROS and LCM (ROS name -> LCM name)
+# dimos_lcm appears to mix ROS1 and ROS2 field names (e.g. ROS1 "nsec" vs ROS2 "nanosec").
+# TODO: confirm the source of this discrepancy in the dimos_lcm message definitions.
+_ROS_TO_LCM_FIELD_MAP: dict[str, str] = {
+ "nanosec": "nsec", # ROS2 Time.nanosec -> LCM Time.nsec
+}
+
+# Reverse mapping (LCM name -> ROS name)
+_LCM_TO_ROS_FIELD_MAP: dict[str, str] = {v: k for k, v in _ROS_TO_LCM_FIELD_MAP.items()}
+
+
+def get_dimos_type(msg_name: str) -> type[DimosMsg] | None:
+ """Try to import dimos.msgs type, return None if not found. Cached.
+
+ Args:
+ msg_name: Message name in format "package.MessageName" (e.g., "geometry_msgs.Vector3")
+
+ Returns:
+ The dimos message type, or None if not found
+ """
+ if msg_name in _dimos_type_cache:
+ return _dimos_type_cache[msg_name]
+
+ try:
+ package, name = msg_name.split(".")
+ module = importlib.import_module(f"dimos.msgs.{package}.{name}")
+ dimos_type = cast("type[DimosMsg]", getattr(module, name))
+ _dimos_type_cache[msg_name] = dimos_type
+ return dimos_type
+ except (ImportError, AttributeError, ValueError):
+ _dimos_type_cache[msg_name] = None
+ return None
+
+
+def derive_lcm_type(dimos_type: type[DimosMsg]) -> type[Any]:
+ """Derive the LCM message type from a dimos message type.
+
+ Args:
+ dimos_type: A dimos message type (e.g., dimos.msgs.sensor_msgs.PointCloud2)
+
+ Returns:
+ The corresponding LCM message type (e.g., dimos_lcm.sensor_msgs.PointCloud2)
+ """
+ msg_name = dimos_type.msg_name # e.g., "sensor_msgs.PointCloud2"
+
+ if msg_name in _lcm_type_cache:
+ return _lcm_type_cache[msg_name]
+
+ parts = msg_name.split(".")
+ if len(parts) != 2:
+ raise ValueError(f"Invalid msg_name format: {msg_name}, expected 'package.MessageName'")
+
+ package, message_name = parts
+ lcm_module = importlib.import_module(f"dimos_lcm.{package}.{message_name}")
+ lcm_type: type[Any] = getattr(lcm_module, message_name)
+ _lcm_type_cache[msg_name] = lcm_type
+ return lcm_type
+
+
+def derive_ros_type(dimos_type: type[DimosMsg]) -> type[ROSMessage]:
+ """Derive the ROS message type from a dimos message type.
+
+ Args:
+ dimos_type: A dimos message type (e.g., dimos.msgs.geometry_msgs.Vector3)
+
+ Returns:
+ The corresponding ROS message type (e.g., geometry_msgs.msg.Vector3)
+
+ Example:
+ msg_name = "geometry_msgs.Vector3" -> geometry_msgs.msg.Vector3
+ """
+ msg_name = dimos_type.msg_name # e.g., "geometry_msgs.Vector3"
+ parts = msg_name.split(".")
+ if len(parts) != 2:
+ raise ValueError(f"Invalid msg_name format: {msg_name}, expected 'package.MessageName'")
+
+ package, message_name = parts
+ ros_module = importlib.import_module(f"{package}.msg")
+ return cast("type[ROSMessage]", getattr(ros_module, message_name))
+
+
+def _copy_ros_to_lcm_recursive(ros_msg: Any, lcm_msg: Any) -> None:
+ """Recursively copy fields from ROS message to LCM message.
+
+ Handles nested messages, arrays, and primitive types.
+
+ Args:
+ ros_msg: Source ROS message
+ lcm_msg: Target LCM message (modified in place)
+ """
+ if not hasattr(ros_msg, "get_fields_and_field_types"):
+ raise TypeError(f"Expected ROS message, got {type(ros_msg).__name__}")
+
+ field_types = ros_msg.get_fields_and_field_types()
+ for ros_field_name in field_types:
+ # Map ROS field name to LCM field name
+ lcm_field_name = _ROS_TO_LCM_FIELD_MAP.get(ros_field_name, ros_field_name)
+
+ if not hasattr(lcm_msg, lcm_field_name):
+ continue
+
+ ros_value = getattr(ros_msg, ros_field_name)
+ lcm_value = getattr(lcm_msg, lcm_field_name)
+
+ # Handle nested messages
+ if hasattr(ros_value, "get_fields_and_field_types"):
+ _copy_ros_to_lcm_recursive(ros_value, lcm_value)
+ # Handle arrays of messages
+ elif isinstance(ros_value, (list, tuple)) and len(ros_value) > 0:
+ if hasattr(ros_value[0], "get_fields_and_field_types"):
+ # Array of nested messages - create LCM instances
+ lcm_array = []
+ for ros_item in ros_value:
+ # Get the LCM element type from the first lcm_value element if available
+ # Otherwise try to derive from ros item
+ if isinstance(lcm_value, list) and len(lcm_value) > 0:
+ lcm_item = type(lcm_value[0])()
+ else:
+ # Try to create matching LCM type
+ lcm_item = _create_lcm_instance_for_ros_msg(ros_item)
+ _copy_ros_to_lcm_recursive(ros_item, lcm_item)
+ lcm_array.append(lcm_item)
+ setattr(lcm_msg, lcm_field_name, lcm_array)
+ else:
+ # Array of primitives - direct copy
+ setattr(lcm_msg, lcm_field_name, list(ros_value))
+ # Handle bytes/data fields
+ elif isinstance(ros_value, (bytes, bytearray)):
+ setattr(lcm_msg, lcm_field_name, bytes(ros_value))
+ # Handle array.array (ROS uses this for data fields)
+ elif hasattr(ros_value, "tobytes"):
+ setattr(lcm_msg, lcm_field_name, ros_value.tobytes())
+ else:
+ # Primitive type - direct copy
+ setattr(lcm_msg, lcm_field_name, ros_value)
+
+ # Update length fields if present (LCM convention: field_name_length)
+ length_field = f"{lcm_field_name}_length"
+ if hasattr(lcm_msg, length_field):
+ value = getattr(lcm_msg, lcm_field_name)
+ if isinstance(value, (list, tuple, bytes, bytearray)):
+ setattr(lcm_msg, length_field, len(value))
+
+
+def _copy_lcm_to_ros_recursive(lcm_msg: Any, ros_msg: Any) -> None:
+ """Recursively copy fields from LCM message to ROS message.
+
+ Handles nested messages, arrays, and primitive types.
+
+ Args:
+ lcm_msg: Source LCM message
+ ros_msg: Target ROS message (modified in place)
+ """
+ if not hasattr(ros_msg, "get_fields_and_field_types"):
+ raise TypeError(f"Expected ROS message, got {type(ros_msg).__name__}")
+
+ field_types = ros_msg.get_fields_and_field_types()
+ for ros_field_name in field_types:
+ # Map ROS field name to LCM field name
+ lcm_field_name = _ROS_TO_LCM_FIELD_MAP.get(ros_field_name, ros_field_name)
+
+ if not hasattr(lcm_msg, lcm_field_name):
+ continue
+
+ lcm_value = getattr(lcm_msg, lcm_field_name)
+ ros_value = getattr(ros_msg, ros_field_name)
+
+ # Handle nested messages
+ if hasattr(ros_value, "get_fields_and_field_types"):
+ _copy_lcm_to_ros_recursive(lcm_value, ros_value)
+ # Handle arrays of messages
+ elif isinstance(lcm_value, (list, tuple)) and len(lcm_value) > 0:
+ if hasattr(lcm_value[0], "lcm_encode"):
+ # Array of nested LCM messages
+ ros_array = []
+ for lcm_item in lcm_value:
+ ros_item = _create_ros_instance_for_lcm_msg(
+ lcm_item, field_types[ros_field_name]
+ )
+ _copy_lcm_to_ros_recursive(lcm_item, ros_item)
+ ros_array.append(ros_item)
+ setattr(ros_msg, ros_field_name, ros_array)
+ else:
+ # Array of primitives - direct copy
+ setattr(ros_msg, ros_field_name, list(lcm_value))
+ # Handle bytes/data fields
+ elif isinstance(lcm_value, (bytes, bytearray)):
+ # ROS data fields might expect array.array
+ if hasattr(ros_value, "frombytes"):
+ import array
+
+ arr = array.array("B")
+ arr.frombytes(lcm_value)
+ setattr(ros_msg, ros_field_name, arr)
+ else:
+ setattr(ros_msg, ros_field_name, bytes(lcm_value))
+ else:
+ # Primitive type - direct copy
+ setattr(ros_msg, ros_field_name, lcm_value)
+
+
+def _create_lcm_instance_for_ros_msg(ros_msg: Any) -> Any:
+ """Create an LCM message instance that matches the ROS message type.
+
+ Args:
+ ros_msg: ROS message to match
+
+ Returns:
+ New LCM message instance
+ """
+ # Get the ROS type name (e.g., "std_msgs.msg.Header" -> "std_msgs.Header")
+ ros_type = type(ros_msg)
+ module_name = ros_type.__module__ # e.g., "std_msgs.msg"
+ class_name = ros_type.__name__ # e.g., "Header"
+
+ # Convert to LCM module path (std_msgs.msg.Header -> dimos_lcm.std_msgs.Header)
+ package = module_name.split(".")[0] # e.g., "std_msgs"
+ lcm_module = importlib.import_module(f"dimos_lcm.{package}.{class_name}")
+ lcm_type = getattr(lcm_module, class_name)
+ return lcm_type()
+
+
+def _create_ros_instance_for_lcm_msg(lcm_msg: Any, ros_type_hint: str) -> Any:
+ """Create a ROS message instance that matches the LCM message type.
+
+ Args:
+ lcm_msg: LCM message to match
+        ros_type_hint: ROS field type string (e.g., "sequence<sensor_msgs/PointField>")
+
+ Returns:
+ New ROS message instance
+ """
+ # Parse the type hint to get the message type
+    # e.g., "sequence<sensor_msgs/PointField>" -> "sensor_msgs", "PointField"
+ # e.g., "sensor_msgs/PointField" -> "sensor_msgs", "PointField"
+
+ match = re.search(r"(\w+)/(\w+)", ros_type_hint)
+ if match:
+ package, class_name = match.groups()
+ ros_module = importlib.import_module(f"{package}.msg")
+ ros_type = getattr(ros_module, class_name)
+ return ros_type()
+
+ # Fallback: try to derive from LCM type
+ lcm_type = type(lcm_msg)
+ module_name = lcm_type.__module__ # e.g., "dimos_lcm.std_msgs.Header"
+ class_name = lcm_type.__name__
+ parts = module_name.split(".")
+ if len(parts) >= 2:
+ package = parts[1] # e.g., "std_msgs"
+ ros_module = importlib.import_module(f"{package}.msg")
+ ros_type = getattr(ros_module, class_name)
+ return ros_type()
+
+ raise ValueError(f"Cannot determine ROS type for LCM message: {lcm_type}")
+
+
+def dimos_to_ros(msg: DimosMsg, ros_type: type[ROSMessage]) -> ROSMessage:
+ """Convert a dimos message to a ROS message.
+
+ For complex types (PointCloud2, Image, CameraInfo), uses LCM roundtrip
+ to properly convert internal representations. For simple types, uses
+ direct field copy.
+
+ Args:
+ msg: Dimos message instance
+ ros_type: Target ROS message type
+
+ Returns:
+ ROS message instance
+ """
+ msg_name = type(msg).msg_name
+
+ if msg_name in COMPLEX_TYPES:
+ # Complex: dimos → encode → decode LCM → copy to ROS
+ lcm_type = derive_lcm_type(type(msg))
+ lcm_bytes = msg.lcm_encode()
+ lcm_msg = lcm_type.lcm_decode(lcm_bytes)
+ ros_msg = ros_type()
+ _copy_lcm_to_ros_recursive(lcm_msg, ros_msg)
+ return ros_msg
+
+ # Simple: recursive field copy (handles nested messages)
+ ros_msg = ros_type()
+ _copy_lcm_to_ros_recursive(msg, ros_msg)
+ return ros_msg
+
+
+def ros_to_dimos(msg: Any, dimos_type: type[DimosMsg]) -> DimosMsg:
+ """Convert a ROS message to a dimos message.
+
+ For complex types (PointCloud2, Image, CameraInfo), uses LCM roundtrip
+ to properly build the dimos internal representation. For simple types,
+ uses direct field copy.
+
+ Args:
+ msg: ROS message instance
+ dimos_type: Target dimos message type
+
+ Returns:
+ Dimos message instance
+ """
+ msg_name = dimos_type.msg_name
+
+ if msg_name in COMPLEX_TYPES:
+ # Complex: ROS → LCM → encode → decode → dimos
+ lcm_type = derive_lcm_type(dimos_type)
+ lcm_msg = lcm_type()
+ _copy_ros_to_lcm_recursive(msg, lcm_msg)
+ return dimos_type.lcm_decode(lcm_msg.lcm_encode())
+
+ # Simple type: recursive field copy (handles nested messages)
+ dimos_msg = dimos_type()
+ _copy_ros_to_lcm_recursive(msg, dimos_msg)
+ return dimos_msg
diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py
index 5f69c3dbd1..5f0b20165e 100644
--- a/dimos/protocol/pubsub/shm/ipc_factory.py
+++ b/dimos/protocol/pubsub/shm/ipc_factory.py
@@ -69,8 +69,13 @@ def shape(self) -> tuple: ... # type: ignore[type-arg]
def dtype(self) -> np.dtype: ... # type: ignore[type-arg]
@abstractmethod
- def publish(self, frame) -> None: # type: ignore[no-untyped-def]
- """Write into inactive buffer, then flip visible index (write control last)."""
+ def publish(self, frame, length: int | None = None) -> None: # type: ignore[no-untyped-def]
+ """Write into inactive buffer, then flip visible index (write control last).
+
+ Args:
+ frame: The numpy array to publish
+ length: Optional length to copy (for variable-size messages). If None, copies full frame.
+ """
...
@abstractmethod
@@ -185,7 +190,7 @@ def shape(self): # type: ignore[no-untyped-def]
def dtype(self): # type: ignore[no-untyped-def]
return self._dtype
- def publish(self, frame) -> None: # type: ignore[no-untyped-def]
+ def publish(self, frame, length: int | None = None) -> None: # type: ignore[no-untyped-def]
assert isinstance(frame, np.ndarray)
assert frame.shape == self._shape and frame.dtype == self._dtype
active = int(self._ctrl[2])
@@ -196,7 +201,11 @@ def publish(self, frame) -> None: # type: ignore[no-untyped-def]
buffer=self._shm_data.buf,
offset=inactive * self._nbytes,
)
- np.copyto(view, frame, casting="no")
+ # Only copy actual payload length if specified, otherwise copy full frame
+ if length is not None and length < len(frame):
+ np.copyto(view[:length], frame[:length], casting="no")
+ else:
+ np.copyto(view, frame, casting="no")
ts = np.int64(time.time_ns())
# Publish order: ts -> idx -> seq
self._ctrl[1] = ts
diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py
index 0006020f6c..89efb82ac3 100644
--- a/dimos/protocol/pubsub/shmpubsub.py
+++ b/dimos/protocol/pubsub/shmpubsub.py
@@ -30,7 +30,9 @@
import uuid
import numpy as np
+import numpy.typing as npt
+from dimos.protocol.pubsub.lcmpubsub import LCMEncoderMixin, Topic
from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel
from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin
from dimos.utils.logging_config import setup_logger
@@ -41,11 +43,6 @@
logger = setup_logger()
-# --------------------------------------------------------------------------------------
-# Configuration (kept local to PubSub now that Service is gone)
-# --------------------------------------------------------------------------------------
-
-
@dataclass
class SharedMemoryConfig:
prefer: str = "auto" # "auto" | "cpu" (DIMOS_IPC_BACKEND overrides), TODO: "cuda"
@@ -53,11 +50,6 @@ class SharedMemoryConfig:
close_channels_on_stop: bool = True
-# --------------------------------------------------------------------------------------
-# Core PubSub with integrated SHM/IPC transport (previously the Service logic)
-# --------------------------------------------------------------------------------------
-
-
class SharedMemoryPubSubBase(PubSub[str, Any]):
"""
Pub/Sub over SharedMemory/CUDA-IPC, modeled after LCMPubSubBase but self-contained.
@@ -81,6 +73,8 @@ class _TopicState:
"dtype",
"last_local_payload",
"last_seq",
+ "publish_buffer",
+ "publish_lock",
"shape",
"stop",
"subs",
@@ -101,6 +95,10 @@ def __init__(self, channel, capacity: int, cp_mod) -> None: # type: ignore[no-u
self.cp = cp_mod
self.last_local_payload: bytes | None = None
self.suppress_counts: dict[bytes, int] = defaultdict(int) # UUID bytes as key
+ # Pre-allocated buffer to avoid allocation on every publish
+ self.publish_buffer: npt.NDArray[np.uint8] = np.zeros(self.shape, dtype=self.dtype)
+ # Lock for thread-safe publish buffer access
+ self.publish_lock = threading.Lock()
# ----- init / lifecycle -------------------------------------------------
@@ -124,7 +122,7 @@ def __init__(
def start(self) -> None:
pref = (self.config.prefer or "auto").lower()
backend = os.getenv("DIMOS_IPC_BACKEND", pref).lower()
- logger.info(f"SharedMemory PubSub starting (backend={backend})")
+ logger.debug(f"SharedMemory PubSub starting (backend={backend})")
# No global thread needed; per-topic fanout starts on first subscribe.
def stop(self) -> None:
@@ -145,7 +143,7 @@ def stop(self) -> None:
except Exception:
pass
self._topics.clear()
- logger.info("SharedMemory PubSub stopped.")
+ logger.debug("SharedMemory PubSub stopped.")
# ----- PubSub API (bytes on the wire) ----------------------------------
@@ -178,15 +176,19 @@ def publish(self, topic: str, message: bytes) -> None:
# Build host frame [len:4] + [uuid:16] + payload and publish
# We embed the message UUID in the frame for echo suppression
- host = np.zeros(st.shape, dtype=st.dtype)
- # Pack: length(4) + uuid(16) + payload
- header = struct.pack(" Callable[[], None]:
"""Subscribe a callback(message: bytes, topic). Returns unsubscribe."""
@@ -216,11 +218,14 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[typ
st = self._ensure_topic(topic)
new_cap = int(capacity)
new_shape = (new_cap + 20,) # +20 for header: length(4) + uuid(16)
- desc = st.channel.reconfigure(new_shape, np.uint8)
- st.capacity = new_cap
- st.shape = new_shape
- st.dtype = np.uint8
- st.last_seq = -1
+ # Lock to ensure no publish is using the buffer while we replace it
+ with st.publish_lock:
+ desc = st.channel.reconfigure(new_shape, np.uint8)
+ st.capacity = new_cap
+ st.shape = new_shape
+ st.dtype = np.uint8
+ st.last_seq = -1
+ st.publish_buffer = np.zeros(new_shape, dtype=np.uint8)
return desc # type: ignore[no-any-return]
# ----- Internals --------------------------------------------------------
@@ -290,36 +295,50 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None:
pass
-# --------------------------------------------------------------------------------------
-# Encoders + concrete PubSub classes
-# --------------------------------------------------------------------------------------
+BytesSharedMemory = SharedMemoryPubSubBase
-class SharedMemoryBytesEncoderMixin(PubSubEncoderMixin[str, bytes]):
- """Identity encoder for raw bytes."""
+class PickleSharedMemory(
+ PickleEncoderMixin[str, Any],
+ SharedMemoryPubSubBase,
+):
+ """SharedMemory pubsub that transports arbitrary Python objects via pickle."""
- def encode(self, msg: bytes, _: str) -> bytes:
- if isinstance(msg, bytes | bytearray | memoryview):
- return bytes(msg)
- raise TypeError(f"SharedMemory expects bytes-like, got {type(msg)!r}")
+ ...
- def decode(self, msg: bytes, _: str) -> bytes:
- return msg
+class LCMSharedMemoryPubSubBase(PubSub[Topic, Any]):
+ """SharedMemory pubsub that uses LCM Topic type, delegating to SharedMemoryPubSubBase."""
-class SharedMemory(
- SharedMemoryBytesEncoderMixin,
- SharedMemoryPubSubBase,
-):
- """SharedMemory pubsub that transports raw bytes."""
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__()
+ self._shm = SharedMemoryPubSubBase(**kwargs)
- ...
+ def start(self) -> None:
+ self._shm.start()
+ def stop(self) -> None:
+ self._shm.stop()
-class PickleSharedMemory(
- PickleEncoderMixin[str, Any],
- SharedMemoryPubSubBase,
+ def publish(self, topic: Topic, message: bytes) -> None:
+ self._shm.publish(str(topic), message)
+
+ def subscribe(
+ self, topic: Topic, callback: Callable[[bytes, Topic], Any]
+ ) -> Callable[[], None]:
+ def wrapper(msg: bytes, _: str) -> None:
+ callback(msg, topic)
+
+ return self._shm.subscribe(str(topic), wrapper)
+
+ def reconfigure(self, topic: Topic, *, capacity: int) -> dict: # type: ignore[type-arg]
+ return self._shm.reconfigure(str(topic), capacity=capacity)
+
+
+class LCMSharedMemory(
+ LCMEncoderMixin,
+ LCMSharedMemoryPubSubBase,
):
- """SharedMemory pubsub that transports arbitrary Python objects via pickle."""
+ """SharedMemory pubsub that uses LCM binary encoding (no pickle overhead)."""
...
diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py
index 28fce3faee..b4e82d3993 100644
--- a/dimos/protocol/pubsub/spec.py
+++ b/dimos/protocol/pubsub/spec.py
@@ -20,13 +20,9 @@
import pickle
from typing import Any, Generic, TypeVar
-from dimos.utils.logging_config import setup_logger
-
MsgT = TypeVar("MsgT")
TopicT = TypeVar("TopicT")
-
-
-logger = setup_logger()
+EncodingT = TypeVar("EncodingT")
class PubSub(Generic[TopicT, MsgT], ABC):
@@ -96,7 +92,7 @@ def _queue_cb(msg: MsgT, topic: TopicT) -> None:
unsubscribe_fn()
-class PubSubEncoderMixin(Generic[TopicT, MsgT], ABC):
+class PubSubEncoderMixin(Generic[TopicT, MsgT, EncodingT], ABC):
"""Mixin that encodes messages before publishing and decodes them after receiving.
Usage: Just specify encoder and decoder as a subclass:
@@ -109,10 +105,10 @@ def decoder(msg, topic):
"""
@abstractmethod
- def encode(self, msg: MsgT, topic: TopicT) -> bytes: ...
+ def encode(self, msg: MsgT, topic: TopicT) -> EncodingT: ...
@abstractmethod
- def decode(self, msg: bytes, topic: TopicT) -> MsgT: ...
+ def decode(self, msg: EncodingT, topic: TopicT) -> MsgT: ...
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
super().__init__(*args, **kwargs)
@@ -132,14 +128,14 @@ def subscribe(
) -> Callable[[], None]:
"""Subscribe with automatic decoding."""
- def wrapper_cb(encoded_data: bytes, topic: TopicT) -> None:
+ def wrapper_cb(encoded_data: EncodingT, topic: TopicT) -> None:
decoded_message = self.decode(encoded_data, topic)
callback(decoded_message, topic)
return super().subscribe(topic, wrapper_cb) # type: ignore[misc, no-any-return]
-class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT]):
+class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT, bytes]):
def encode(self, msg: MsgT, *_: TopicT) -> bytes: # type: ignore[return]
try:
return pickle.dumps(msg)
diff --git a/dimos/protocol/pubsub/test_encoder.py b/dimos/protocol/pubsub/test_encoder.py
index f39bd170d5..38aac4664d 100644
--- a/dimos/protocol/pubsub/test_encoder.py
+++ b/dimos/protocol/pubsub/test_encoder.py
@@ -15,6 +15,7 @@
# limitations under the License.
import json
+from typing import Any
from dimos.protocol.pubsub.memory import Memory, MemoryWithJSONEncoder
@@ -24,7 +25,7 @@ def test_json_encoded_pubsub() -> None:
pubsub = MemoryWithJSONEncoder()
received_messages = []
- def callback(message, topic) -> None:
+ def callback(message: Any, topic: str) -> None:
received_messages.append(message)
# Subscribe to a topic
@@ -56,7 +57,7 @@ def test_json_encoding_edge_cases() -> None:
pubsub = MemoryWithJSONEncoder()
received_messages = []
- def callback(message, topic) -> None:
+ def callback(message: Any, topic: str) -> None:
received_messages.append(message)
pubsub.subscribe("edge_cases", callback)
@@ -84,10 +85,10 @@ def test_multiple_subscribers_with_encoding() -> None:
received_messages_1 = []
received_messages_2 = []
- def callback_1(message, topic) -> None:
+ def callback_1(message: Any, topic: str) -> None:
received_messages_1.append(message)
- def callback_2(message, topic) -> None:
+ def callback_2(message: Any, topic: str) -> None:
received_messages_2.append(f"callback_2: {message}")
pubsub.subscribe("json_topic", callback_1)
@@ -130,9 +131,9 @@ def test_data_actually_encoded_in_transit() -> None:
class SpyMemory(Memory):
def __init__(self) -> None:
super().__init__()
- self.raw_messages_received = []
+ self.raw_messages_received: list[tuple[str, Any, type]] = []
- def publish(self, topic: str, message) -> None:
+ def publish(self, topic: str, message: Any) -> None:
# Capture what actually gets published
self.raw_messages_received.append((topic, message, type(message)))
super().publish(topic, message)
@@ -142,9 +143,9 @@ class SpyMemoryWithJSON(MemoryWithJSONEncoder, SpyMemory):
pass
pubsub = SpyMemoryWithJSON()
- received_decoded = []
+ received_decoded: list[Any] = []
- def callback(message, topic) -> None:
+ def callback(message: Any, topic: str) -> None:
received_decoded.append(message)
pubsub.subscribe("test_topic", callback)
diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py
index d06bf20716..8165be9fef 100644
--- a/dimos/protocol/pubsub/test_lcmpubsub.py
+++ b/dimos/protocol/pubsub/test_lcmpubsub.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections.abc import Generator
import time
+from typing import Any
import pytest
@@ -26,7 +28,7 @@
@pytest.fixture
-def lcm_pub_sub_base():
+def lcm_pub_sub_base() -> Generator[LCMPubSubBase, None, None]:
lcm = LCMPubSubBase(autoconf=True)
lcm.start()
yield lcm
@@ -34,7 +36,7 @@ def lcm_pub_sub_base():
@pytest.fixture
-def pickle_lcm():
+def pickle_lcm() -> Generator[PickleLCM, None, None]:
lcm = PickleLCM(autoconf=True)
lcm.start()
yield lcm
@@ -42,7 +44,7 @@ def pickle_lcm():
@pytest.fixture
-def lcm():
+def lcm() -> Generator[LCM, None, None]:
lcm = LCM(autoconf=True)
lcm.start()
yield lcm
@@ -54,7 +56,7 @@ class MockLCMMessage:
msg_name = "geometry_msgs.Mock"
- def __init__(self, data) -> None:
+ def __init__(self, data: Any) -> None:
self.data = data
def lcm_encode(self) -> bytes:
@@ -64,19 +66,19 @@ def lcm_encode(self) -> bytes:
def lcm_decode(cls, data: bytes) -> "MockLCMMessage":
return cls(data.decode("utf-8"))
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return isinstance(other, MockLCMMessage) and self.data == other.data
-def test_LCMPubSubBase_pubsub(lcm_pub_sub_base) -> None:
+def test_LCMPubSubBase_pubsub(lcm_pub_sub_base: LCMPubSubBase) -> None:
lcm = lcm_pub_sub_base
- received_messages = []
+ received_messages: list[tuple[Any, Any]] = []
topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage)
test_message = MockLCMMessage("test_data")
- def callback(msg, topic) -> None:
+ def callback(msg: Any, topic: Any) -> None:
received_messages.append((msg, topic))
lcm.subscribe(topic, callback)
@@ -97,13 +99,13 @@ def callback(msg, topic) -> None:
assert received_topic == topic
-def test_lcm_autodecoder_pubsub(lcm) -> None:
- received_messages = []
+def test_lcm_autodecoder_pubsub(lcm: LCM) -> None:
+ received_messages: list[tuple[Any, Any]] = []
topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage)
test_message = MockLCMMessage("test_data")
- def callback(msg, topic) -> None:
+ def callback(msg: Any, topic: Any) -> None:
received_messages.append((msg, topic))
lcm.subscribe(topic, callback)
@@ -133,12 +135,12 @@ def callback(msg, topic) -> None:
# passes some geometry types through LCM
@pytest.mark.parametrize("test_message", test_msgs)
-def test_lcm_geometry_msgs_pubsub(test_message, lcm) -> None:
- received_messages = []
+def test_lcm_geometry_msgs_pubsub(test_message: Any, lcm: LCM) -> None:
+ received_messages: list[tuple[Any, Any]] = []
topic = Topic(topic="/test_topic", lcm_type=test_message.__class__)
- def callback(msg, topic) -> None:
+ def callback(msg: Any, topic: Any) -> None:
received_messages.append((msg, topic))
lcm.subscribe(topic, callback)
@@ -164,13 +166,13 @@ def callback(msg, topic) -> None:
# passes some geometry types through pickle LCM
@pytest.mark.parametrize("test_message", test_msgs)
-def test_lcm_geometry_msgs_autopickle_pubsub(test_message, pickle_lcm) -> None:
+def test_lcm_geometry_msgs_autopickle_pubsub(test_message: Any, pickle_lcm: PickleLCM) -> None:
lcm = pickle_lcm
- received_messages = []
+ received_messages: list[tuple[Any, Any]] = []
topic = Topic(topic="/test_topic")
- def callback(msg, topic) -> None:
+ def callback(msg: Any, topic: Any) -> None:
received_messages.append((msg, topic))
lcm.subscribe(topic, callback)
diff --git a/dimos/protocol/pubsub/test_rospubsub.py b/dimos/protocol/pubsub/test_rospubsub.py
new file mode 100644
index 0000000000..3a3a020586
--- /dev/null
+++ b/dimos/protocol/pubsub/test_rospubsub.py
@@ -0,0 +1,286 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Generator
+import threading
+
+from dimos_lcm.geometry_msgs import PointStamped
+import numpy as np
+import pytest
+
+from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped
+from dimos.msgs.geometry_msgs.Twist import Twist
+from dimos.msgs.geometry_msgs.Vector3 import Vector3
+from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2
+from dimos.protocol.pubsub.rospubsub import DimosROS, ROSTopic
+
+# Add msg_name to LCM PointStamped for testing nested message conversion
+PointStamped.msg_name = "geometry_msgs.PointStamped"
+from dimos.utils.data import get_data
+from dimos.utils.testing import TimedSensorReplay
+
+
+def ros_node():
+ ros = DimosROS()
+ ros.start()
+ try:
+ yield ros
+ finally:
+ ros.stop()
+
+
+@pytest.fixture()
+def publisher() -> Generator[DimosROS, None, None]:
+ yield from ros_node()
+
+
+@pytest.fixture()
+def subscriber() -> Generator[DimosROS, None, None]:
+ yield from ros_node()
+
+
+@pytest.mark.ros
+def test_basic_conversion(publisher, subscriber):
+ """Test Vector3 publish/subscribe through ROS.
+
+ Simple flat dimos.msgs type with no nesting (just x/y/z floats).
+ """
+ topic = ROSTopic("/test_ros_topic", Vector3)
+
+ received = []
+ event = threading.Event()
+
+ def callback(msg, t):
+ received.append(msg)
+ event.set()
+
+ subscriber.subscribe(topic, callback)
+ publisher.publish(topic, Vector3(1.0, 2.0, 3.0))
+
+ assert event.wait(timeout=2.0), "No message received"
+ assert len(received) == 1
+ msg = received[0]
+ assert msg.x == 1.0
+ assert msg.y == 2.0
+ assert msg.z == 3.0
+
+
+@pytest.mark.ros
+def test_pointcloud2_pubsub(publisher, subscriber):
+ """Test PointCloud2 publish/subscribe through ROS.
+
+ COMPLEX_TYPE - has non-standard attributes (numpy arrays, custom accessors)
+ that can't be treated like a standard message with direct field copy.
+ Uses LCM encode/decode roundtrip to properly convert internal representation.
+ """
+ dir_name = get_data("unitree_go2_bigoffice")
+
+ # Load real lidar data from replay (5 seconds in)
+ replay = TimedSensorReplay(f"{dir_name}/lidar")
+ original = replay.find_closest_seek(5.0)
+
+ assert original is not None, "Failed to load lidar data from replay"
+ assert len(original) > 0, "Loaded empty pointcloud"
+
+ topic = ROSTopic("/test_pointcloud2", PointCloud2)
+
+ received = []
+ event = threading.Event()
+
+ def callback(msg, t):
+ received.append(msg)
+ event.set()
+
+ subscriber.subscribe(topic, callback)
+ publisher.publish(topic, original)
+
+ assert event.wait(timeout=5.0), "No PointCloud2 message received"
+ assert len(received) == 1
+
+ converted = received[0]
+
+ # Verify point cloud data is preserved
+ original_points, _ = original.as_numpy()
+ converted_points, _ = converted.as_numpy()
+
+ assert len(original_points) == len(converted_points), (
+ f"Point count mismatch: {len(original_points)} vs {len(converted_points)}"
+ )
+
+ np.testing.assert_allclose(
+ original_points,
+ converted_points,
+ rtol=1e-5,
+ atol=1e-5,
+ err_msg="Points don't match after ROS pubsub roundtrip",
+ )
+
+ # Verify frame_id is preserved
+ assert converted.frame_id == original.frame_id
+
+ # Verify timestamp is preserved (within 1ms tolerance)
+ assert abs(original.ts - converted.ts) < 0.001
+
+
+@pytest.mark.ros
+def test_pointcloud2_empty_pubsub(publisher, subscriber):
+ """Test empty PointCloud2 publish/subscribe.
+
+ Edge case for COMPLEX_TYPE with zero points.
+ """
+ original = PointCloud2.from_numpy(
+ np.array([]).reshape(0, 3),
+ frame_id="empty_frame",
+ timestamp=1234567890.0,
+ )
+
+ topic = ROSTopic("/test_empty_pointcloud", PointCloud2)
+
+ received = []
+ event = threading.Event()
+
+ def callback(msg, t):
+ received.append(msg)
+ event.set()
+
+ subscriber.subscribe(topic, callback)
+ publisher.publish(topic, original)
+
+ assert event.wait(timeout=2.0), "No empty PointCloud2 message received"
+ assert len(received) == 1
+ assert len(received[0]) == 0
+
+
+@pytest.mark.ros
+def test_posestamped_pubsub(publisher, subscriber):
+ """Test PoseStamped publish/subscribe through ROS.
+
+ COMPLEX_TYPE with custom dimos.msgs implementation and nested messages
+ (Header, Pose containing Point and Quaternion). Uses LCM roundtrip.
+ """
+ original = PoseStamped(
+ ts=1234567890.123456,
+ frame_id="base_link",
+ position=[1.0, 2.0, 3.0],
+ orientation=[0.0, 0.0, 0.7071068, 0.7071068], # 90 degree yaw
+ )
+
+ topic = ROSTopic("/test_posestamped", PoseStamped)
+
+ received = []
+ event = threading.Event()
+
+ def callback(msg, t):
+ received.append(msg)
+ event.set()
+
+ subscriber.subscribe(topic, callback)
+ publisher.publish(topic, original)
+
+ assert event.wait(timeout=2.0), "No PoseStamped message received"
+ assert len(received) == 1
+
+ converted = received[0]
+
+ # Verify all fields preserved
+ assert converted.frame_id == original.frame_id
+ assert abs(converted.ts - original.ts) < 0.001 # 1ms tolerance
+ assert converted.x == original.x
+ assert converted.y == original.y
+ assert converted.z == original.z
+ np.testing.assert_allclose(converted.orientation.z, original.orientation.z, rtol=1e-5)
+ np.testing.assert_allclose(converted.orientation.w, original.orientation.w, rtol=1e-5)
+
+
+@pytest.mark.ros
+def test_pointstamped_pubsub(publisher, subscriber):
+ """Test PointStamped publish/subscribe through ROS.
+
+ Raw LCM type with nested messages (Header, Point) but NO custom dimos.msgs
+ implementation. Tests recursive field copy for non-COMPLEX_TYPES.
+ """
+ original = PointStamped()
+ original.header.stamp.sec = 1234567890
+ original.header.stamp.nsec = 123456000
+ original.header.frame_id = "map"
+ original.point.x = 1.5
+ original.point.y = 2.5
+ original.point.z = 3.5
+
+ topic = ROSTopic("/test_pointstamped", PointStamped)
+
+ received = []
+ event = threading.Event()
+
+ def callback(msg, t):
+ received.append(msg)
+ event.set()
+
+ subscriber.subscribe(topic, callback)
+ publisher.publish(topic, original)
+
+ assert event.wait(timeout=2.0), "No PointStamped message received"
+ assert len(received) == 1
+
+ converted = received[0]
+
+ # Verify nested header fields are preserved
+ assert converted.header.frame_id == original.header.frame_id
+ assert converted.header.stamp.sec == original.header.stamp.sec
+ assert converted.header.stamp.nsec == original.header.stamp.nsec
+
+ # Verify point coordinates are preserved
+ assert converted.point.x == original.point.x
+ assert converted.point.y == original.point.y
+ assert converted.point.z == original.point.z
+
+
+@pytest.mark.ros
+def test_twist_pubsub(publisher, subscriber):
+ """Test Twist publish/subscribe through ROS.
+
+ dimos.msgs type with nested Vector3 messages (linear, angular).
+ Tests recursive field copy with custom dimos.msgs nested types.
+ """
+ original = Twist(
+ linear=[1.0, 2.0, 3.0],
+ angular=[0.1, 0.2, 0.3],
+ )
+
+ topic = ROSTopic("/test_twist", Twist)
+
+ received = []
+ event = threading.Event()
+
+ def callback(msg, t):
+ received.append(msg)
+ event.set()
+
+ subscriber.subscribe(topic, callback)
+ publisher.publish(topic, original)
+
+ assert event.wait(timeout=2.0), "No Twist message received"
+ assert len(received) == 1
+
+ converted = received[0]
+
+ # Verify linear velocity preserved
+ assert converted.linear.x == original.linear.x
+ assert converted.linear.y == original.linear.y
+ assert converted.linear.z == original.linear.z
+
+ # Verify angular velocity preserved
+ assert converted.angular.x == original.angular.x
+ assert converted.angular.y == original.angular.y
+ assert converted.angular.z == original.angular.z
diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py
index 91e8514b70..357e1dfa1e 100644
--- a/dimos/protocol/pubsub/test_spec.py
+++ b/dimos/protocol/pubsub/test_spec.py
@@ -15,7 +15,7 @@
# limitations under the License.
import asyncio
-from collections.abc import Callable
+from collections.abc import Callable, Generator
from contextlib import contextmanager
import time
from typing import Any
@@ -28,7 +28,7 @@
@contextmanager
-def memory_context():
+def memory_context() -> Generator[Memory, None, None]:
"""Context manager for Memory PubSub implementation."""
memory = Memory()
try:
@@ -47,7 +47,7 @@ def memory_context():
from dimos.protocol.pubsub.redispubsub import Redis
@contextmanager
- def redis_context():
+ def redis_context() -> Generator[Redis, None, None]:
redis_pubsub = Redis()
redis_pubsub.start()
yield redis_pubsub
@@ -61,9 +61,54 @@ def redis_context():
# either redis is not installed or the server is not running
print("Redis not available")
+try:
+ from geometry_msgs.msg import Vector3 as ROSVector3
+ from rclpy.qos import (
+ QoSDurabilityPolicy,
+ QoSHistoryPolicy,
+ QoSProfile,
+ QoSReliabilityPolicy,
+ )
+
+ from dimos.protocol.pubsub.rospubsub import RawROS, RawROSTopic
+
+ # Use RELIABLE QoS with larger depth for testing
+ _test_qos = QoSProfile(
+ reliability=QoSReliabilityPolicy.RELIABLE,
+ history=QoSHistoryPolicy.KEEP_ALL,
+ durability=QoSDurabilityPolicy.VOLATILE,
+ depth=5000,
+ )
+
+ @contextmanager
+ def ros_context() -> Generator[RawROS, None, None]:
+ ros_pubsub = RawROS(qos=_test_qos)
+ ros_pubsub.start()
+ time.sleep(0.1)
+ try:
+ yield ros_pubsub
+ finally:
+ ros_pubsub.stop()
+
+ testdata.append(
+ (
+ ros_context,
+ RawROSTopic(topic="/test_ros_topic", ros_type=ROSVector3, qos=_test_qos),
+ [
+ ROSVector3(x=1.0, y=2.0, z=3.0),
+ ROSVector3(x=4.0, y=5.0, z=6.0),
+ ROSVector3(x=7.0, y=8.0, z=9.0),
+ ],
+ )
+ )
+
+except ImportError:
+ # ROS 2 not available
+ print("ROS 2 not available")
+
@contextmanager
-def lcm_context():
+def lcm_context() -> Generator[LCM, None, None]:
lcm_pubsub = LCM(autoconf=True)
lcm_pubsub.start()
yield lcm_pubsub
@@ -83,7 +128,7 @@ def lcm_context():
@contextmanager
-def shared_memory_cpu_context():
+def shared_memory_cpu_context() -> Generator[PickleSharedMemory, None, None]:
shared_mem_pubsub = PickleSharedMemory(prefer="cpu")
shared_mem_pubsub.start()
yield shared_mem_pubsub
@@ -100,13 +145,13 @@ def shared_memory_cpu_context():
@pytest.mark.parametrize("pubsub_context, topic, values", testdata)
-def test_store(pubsub_context, topic, values) -> None:
+def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None:
with pubsub_context() as x:
# Create a list to capture received messages
- received_messages = []
+ received_messages: list[Any] = []
# Define callback function that stores received messages
- def callback(message, _) -> None:
+ def callback(message: Any, _: Any) -> None:
received_messages.append(message)
# Subscribe to the topic with our callback
@@ -125,18 +170,20 @@ def callback(message, _) -> None:
@pytest.mark.parametrize("pubsub_context, topic, values", testdata)
-def test_multiple_subscribers(pubsub_context, topic, values) -> None:
+def test_multiple_subscribers(
+ pubsub_context: Callable[[], Any], topic: Any, values: list[Any]
+) -> None:
"""Test that multiple subscribers receive the same message."""
with pubsub_context() as x:
# Create lists to capture received messages for each subscriber
- received_messages_1 = []
- received_messages_2 = []
+ received_messages_1: list[Any] = []
+ received_messages_2: list[Any] = []
# Define callback functions
- def callback_1(message, topic) -> None:
+ def callback_1(message: Any, topic: Any) -> None:
received_messages_1.append(message)
- def callback_2(message, topic) -> None:
+ def callback_2(message: Any, topic: Any) -> None:
received_messages_2.append(message)
# Subscribe both callbacks to the same topic
@@ -157,14 +204,14 @@ def callback_2(message, topic) -> None:
@pytest.mark.parametrize("pubsub_context, topic, values", testdata)
-def test_unsubscribe(pubsub_context, topic, values) -> None:
+def test_unsubscribe(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None:
"""Test that unsubscribed callbacks don't receive messages."""
with pubsub_context() as x:
# Create a list to capture received messages
- received_messages = []
+ received_messages: list[Any] = []
# Define callback function
- def callback(message, topic) -> None:
+ def callback(message: Any, topic: Any) -> None:
received_messages.append(message)
# Subscribe and get unsubscribe function
@@ -184,14 +231,16 @@ def callback(message, topic) -> None:
@pytest.mark.parametrize("pubsub_context, topic, values", testdata)
-def test_multiple_messages(pubsub_context, topic, values) -> None:
+def test_multiple_messages(
+ pubsub_context: Callable[[], Any], topic: Any, values: list[Any]
+) -> None:
"""Test that subscribers receive multiple messages in order."""
with pubsub_context() as x:
# Create a list to capture received messages
- received_messages = []
+ received_messages: list[Any] = []
# Define callback function
- def callback(message, topic) -> None:
+ def callback(message: Any, topic: Any) -> None:
received_messages.append(message)
# Subscribe to the topic
@@ -212,7 +261,9 @@ def callback(message, topic) -> None:
@pytest.mark.parametrize("pubsub_context, topic, values", testdata)
@pytest.mark.asyncio
-async def test_async_iterator(pubsub_context, topic, values) -> None:
+async def test_async_iterator(
+ pubsub_context: Callable[[], Any], topic: Any, values: list[Any]
+) -> None:
"""Test that async iterator receives messages correctly."""
with pubsub_context() as x:
# Get the messages to send (using the rest of the values)
@@ -260,29 +311,35 @@ async def consume_messages() -> None:
assert received_messages == messages_to_send
+@pytest.mark.integration
@pytest.mark.parametrize("pubsub_context, topic, values", testdata)
-def test_high_volume_messages(pubsub_context, topic, values) -> None:
- """Test that all 5000 messages are received correctly."""
+def test_high_volume_messages(
+ pubsub_context: Callable[[], Any], topic: Any, values: list[Any]
+) -> None:
+ """Test that all 5k messages are received correctly.
+    Limited to 5k because the ROS transport cannot handle more.
+    We might want separate per-transport expectations later.
+ """
with pubsub_context() as x:
# Create a list to capture received messages
- received_messages = []
+ received_messages: list[Any] = []
last_message_time = [time.time()] # Use list to allow modification in callback
# Define callback function
- def callback(message, topic) -> None:
+ def callback(message: Any, topic: Any) -> None:
received_messages.append(message)
last_message_time[0] = time.time()
# Subscribe to the topic
x.subscribe(topic, callback)
- # Publish 10000 messages
- num_messages = 10000
+ # Publish 5000 messages
+ num_messages = 5000
for _ in range(num_messages):
x.publish(topic, values[0])
# Wait until no messages received for 0.5 seconds
- timeout = 1.0 # Maximum time to wait
+ timeout = 2.0 # Maximum time to wait
stable_duration = 0.1 # Time without new messages to consider done
start_time = time.time()
diff --git a/dimos/protocol/rpc/test_spec.py b/dimos/protocol/rpc/test_spec.py
index 9fb8f65eb7..c29db13703 100644
--- a/dimos/protocol/rpc/test_spec.py
+++ b/dimos/protocol/rpc/test_spec.py
@@ -293,6 +293,7 @@ def callback(val) -> None:
unsub_server()
+@pytest.mark.integration
@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_timeout(rpc_context, impl_name: str) -> None:
"""Test that RPC calls properly timeout."""
diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py
index a0ca8c4796..cf0a0647d8 100644
--- a/dimos/protocol/service/lcmservice.py
+++ b/dimos/protocol/service/lcmservice.py
@@ -16,11 +16,8 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
-from functools import cache
import os
import platform
-import subprocess
-import sys
import threading
import traceback
from typing import Protocol, runtime_checkable
@@ -28,216 +25,46 @@
import lcm
from dimos.protocol.service.spec import Service
+from dimos.protocol.service.system_configurator import (
+ BufferConfiguratorLinux,
+ BufferConfiguratorMacOS,
+ MaxFileConfiguratorMacOS,
+ MulticastConfiguratorLinux,
+ MulticastConfiguratorMacOS,
+ SystemConfigurator,
+ configure_system,
+)
from dimos.utils.logging_config import setup_logger
logger = setup_logger()
+_DEFAULT_LCM_HOST = "239.255.76.67"
+_DEFAULT_LCM_PORT = "7667"
+# LCM_DEFAULT_URL is used by LCM (we didn't pick that env var name)
+_DEFAULT_LCM_URL = os.getenv(
+ "LCM_DEFAULT_URL", f"udpm://{_DEFAULT_LCM_HOST}:{_DEFAULT_LCM_PORT}?ttl=0"
+)
-@cache
-def check_root() -> bool:
- """Return True if the current process is running as root (UID 0)."""
- try:
- return os.geteuid() == 0
- except AttributeError:
- # Platforms without geteuid (e.g. Windows) – assume non-root.
- return False
-
-
-def check_multicast() -> list[str]:
- """Check if multicast configuration is needed and return required commands."""
- commands_needed = []
-
- sudo = "" if check_root() else "sudo "
+def autoconf(check_only: bool = False) -> None:
+ # check multicast and buffer sizes
system = platform.system()
-
+ checks: list[SystemConfigurator] = []
if system == "Linux":
- # Linux commands
- loopback_interface = "lo"
- # Check if loopback interface has multicast enabled
- try:
- result = subprocess.run(
- ["ip", "link", "show", loopback_interface], capture_output=True, text=True
- )
- if "MULTICAST" not in result.stdout:
- commands_needed.append(f"{sudo}ifconfig {loopback_interface} multicast")
- except Exception:
- commands_needed.append(f"{sudo}ifconfig {loopback_interface} multicast")
-
- # Check if multicast route exists
- try:
- result = subprocess.run(
- ["ip", "route", "show", "224.0.0.0/4"], capture_output=True, text=True
- )
- if not result.stdout.strip():
- commands_needed.append(
- f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev {loopback_interface}"
- )
- except Exception:
- commands_needed.append(
- f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev {loopback_interface}"
- )
-
- elif system == "Darwin": # macOS
- loopback_interface = "lo0"
- # Check if multicast route exists
- try:
- result = subprocess.run(["netstat", "-nr"], capture_output=True, text=True)
- route_exists = "224.0.0.0/4" in result.stdout or "224.0.0/4" in result.stdout
- if not route_exists:
- commands_needed.append(
- f"{sudo}route add -net 224.0.0.0/4 -interface {loopback_interface}"
- )
- except Exception:
- commands_needed.append(
- f"{sudo}route add -net 224.0.0.0/4 -interface {loopback_interface}"
- )
-
+ checks = [
+ MulticastConfiguratorLinux(loopback_interface="lo"),
+ BufferConfiguratorLinux(),
+ ]
+ elif system == "Darwin":
+ checks = [
+ MulticastConfiguratorMacOS(loopback_interface="lo0"),
+ BufferConfiguratorMacOS(),
+ MaxFileConfiguratorMacOS(),
+ ]
else:
- # For other systems, skip multicast configuration
- logger.warning(f"Multicast configuration not supported on {system}")
-
- return commands_needed
-
-
-def _set_net_value(commands_needed: list[str], sudo: str, name: str, value: int) -> int | None:
- try:
- result = subprocess.run(["sysctl", name], capture_output=True, text=True)
- if result.returncode == 0:
- current = int(result.stdout.replace(":", "=").split("=")[1].strip())
- else:
- current = None
- if not current or current < value:
- commands_needed.append(f"{sudo}sysctl -w {name}={value}")
- return current
- except:
- commands_needed.append(f"{sudo}sysctl -w {name}={value}")
- return None
-
-
-TARGET_RMEM_SIZE = 2097152 # prev was 67108864
-TARGET_MAX_SOCKET_BUFFER_SIZE_MACOS = 8388608
-TARGET_MAX_DGRAM_SIZE_MACOS = 65535
-
-
-def check_buffers() -> tuple[list[str], int | None]:
- """Check if buffer configuration is needed and return required commands and current size.
-
- Returns:
- Tuple of (commands_needed, current_max_buffer_size)
- """
- commands_needed: list[str] = []
- current_max = None
-
- sudo = "" if check_root() else "sudo "
- system = platform.system()
-
- if system == "Linux":
- # Linux buffer configuration
- current_max = _set_net_value(commands_needed, sudo, "net.core.rmem_max", TARGET_RMEM_SIZE)
- _set_net_value(commands_needed, sudo, "net.core.rmem_default", TARGET_RMEM_SIZE)
- elif system == "Darwin": # macOS
- # macOS buffer configuration - check and set UDP buffer related sysctls
- current_max = _set_net_value(
- commands_needed, sudo, "kern.ipc.maxsockbuf", TARGET_MAX_SOCKET_BUFFER_SIZE_MACOS
- )
- _set_net_value(commands_needed, sudo, "net.inet.udp.recvspace", TARGET_RMEM_SIZE)
- _set_net_value(commands_needed, sudo, "net.inet.udp.maxdgram", TARGET_MAX_DGRAM_SIZE_MACOS)
- else:
- # For other systems, skip buffer configuration
- logger.warning(f"Buffer configuration not supported on {system}")
-
- return commands_needed, current_max
-
-
-def check_system() -> None:
- """Check if system configuration is needed and exit only for critical issues.
-
- Multicast configuration is critical for LCM to work.
- Buffer sizes are performance optimizations - warn but don't fail in containers.
- """
- if os.environ.get("CI"):
- logger.debug("CI environment detected: Skipping system configuration checks.")
- return
-
- multicast_commands = check_multicast()
- buffer_commands, current_buffer_size = check_buffers()
-
- # Check multicast first - this is critical
- if multicast_commands:
- logger.error(
- "Critical: Multicast configuration required. Please run the following commands:"
- )
- for cmd in multicast_commands:
- logger.error(f" {cmd}")
- logger.error("\nThen restart your application.")
- sys.exit(1)
-
- # Buffer configuration is just for performance
- elif buffer_commands:
- if current_buffer_size:
- logger.warning(
- f"UDP buffer size limited to {current_buffer_size} bytes ({current_buffer_size // 1024}KB). Large LCM packets may fail."
- )
- else:
- logger.warning("UDP buffer sizes are limited. Large LCM packets may fail.")
- logger.warning("For better performance, consider running:")
- for cmd in buffer_commands:
- logger.warning(f" {cmd}")
- logger.warning("Note: This may not be possible in Docker containers.")
-
-
-def autoconf() -> None:
- """Auto-configure system by running checks and executing required commands if needed."""
- if os.environ.get("CI"):
- logger.info("CI environment detected: Skipping automatic system configuration.")
- return
-
- platform.system()
-
- commands_needed = []
-
- # Check multicast configuration
- commands_needed.extend(check_multicast())
-
- # Check buffer configuration
- buffer_commands, _ = check_buffers()
- commands_needed.extend(buffer_commands)
-
- if not commands_needed:
+ logger.error(f"System configuration not supported on {system}")
return
-
- logger.info("System configuration required. Executing commands...")
-
- for cmd in commands_needed:
- logger.info(f" Running: {cmd}")
- try:
- # Split command into parts for subprocess
- cmd_parts = cmd.split()
- subprocess.run(cmd_parts, capture_output=True, text=True, check=True)
- logger.info(" ✓ Success")
- except subprocess.CalledProcessError as e:
- # Check if this is a multicast/route command or a sysctl command
- if "route" in cmd or "multicast" in cmd:
- # Multicast/route failures should still fail
- logger.error(f" ✗ Failed to configure multicast: {e}")
- logger.error(f" stdout: {e.stdout}")
- logger.error(f" stderr: {e.stderr}")
- raise
- elif "sysctl" in cmd:
- # Sysctl failures are just warnings (likely docker/container)
- logger.warning(
- f" ✗ Not able to auto-configure UDP buffer sizes (likely docker image): {e}"
- )
- except Exception as e:
- logger.error(f" ✗ Error: {e}")
- if "route" in cmd or "multicast" in cmd:
- raise
-
- logger.info("System configuration completed.")
-
-
-_DEFAULT_LCM_URL_MACOS = "udpm://239.255.76.67:7667?ttl=0"
+ configure_system(checks, check_only=check_only)
@dataclass
@@ -248,9 +75,8 @@ class LCMConfig:
lcm: lcm.LCM | None = None
def __post_init__(self) -> None:
- if self.url is None and platform.system() == "Darwin":
- # On macOS, use multicast with TTL=0 to keep traffic local
- self.url = _DEFAULT_LCM_URL_MACOS
+ if self.url is None:
+ self.url = _DEFAULT_LCM_URL
@runtime_checkable
@@ -334,13 +160,10 @@ def start(self) -> None:
else:
self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM()
- if self.config.autoconf:
- autoconf()
- else:
- try:
- check_system()
- except Exception as e:
- print(f"Error checking system configuration: {e}")
+ try:
+ autoconf(check_only=not self.config.autoconf)
+ except Exception as e:
+ print(f"Error checking system configuration: {e}")
self._stop_event.clear()
self._thread = threading.Thread(target=self._lcm_loop)
diff --git a/dimos/protocol/service/system_configurator.py b/dimos/protocol/service/system_configurator.py
new file mode 100644
index 0000000000..44b8c45276
--- /dev/null
+++ b/dimos/protocol/service/system_configurator.py
@@ -0,0 +1,436 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from functools import cache
+import os
+import re
+import resource
+import subprocess
+from typing import Any
+
+# ----------------------------- sudo helpers -----------------------------
+
+
+@cache
+def _is_root_user() -> bool:
+ try:
+ return os.geteuid() == 0
+ except AttributeError:
+ return False
+
+
+def sudo_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess[str]:
+ if _is_root_user():
+ return subprocess.run(list(args), **kwargs)
+ return subprocess.run(["sudo", *args], **kwargs)
+
+
+def _read_sysctl_int(name: str) -> int | None:
+ try:
+ result = subprocess.run(["sysctl", name], capture_output=True, text=True)
+ if result.returncode != 0:
+ print(
+ f"[sysctl] ERROR: `sysctl {name}` rc={result.returncode} stderr={result.stderr!r}"
+ )
+ return None
+
+ text = result.stdout.strip().replace(":", "=")
+ if "=" not in text:
+ print(f"[sysctl] ERROR: unexpected output for {name}: {text!r}")
+ return None
+
+ return int(text.split("=", 1)[1].strip())
+ except Exception as error:
+ print(f"[sysctl] ERROR: reading {name}: {error}")
+ return None
+
+
+def _write_sysctl_int(name: str, value: int) -> None:
+ sudo_run("sysctl", "-w", f"{name}={value}", check=True, text=True, capture_output=False)
+
+
+# -------------------------- base class for system config checks/requirements --------------------------
+
+
+class SystemConfigurator(ABC):
+ critical: bool = False
+
+ @abstractmethod
+ def check(self) -> bool:
+ """Return True if configured. Log errors and return False on uncertainty."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def explanation(self) -> str | None:
+ """
+ Return a human-readable summary of what would be done (sudo commands) if not configured.
+ Return None when no changes are needed.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def fix(self) -> None:
+ """Apply fixes (may attempt sudo, catch, and apply fallback measures if needed)."""
+ raise NotImplementedError
+
+
+# ----------------------------- generic enforcement of system configs -----------------------------
+
+
+def configure_system(checks: list[SystemConfigurator], check_only: bool = False) -> None:
+ if os.environ.get("CI"):
+ print("CI environment detected: skipping system configuration.")
+ return
+
+ # run checks
+ failing = [check for check in checks if not check.check()]
+ if not failing:
+ return
+
+ # ask for permission to modify system
+ explanations: list[str] = [msg for check in failing if (msg := check.explanation()) is not None]
+
+ if explanations:
+ print("System configuration changes are recommended/required:\n")
+ print("\n\n".join(explanations))
+ print()
+
+ if check_only:
+ return
+
+ try:
+ answer = input("Apply these changes now? [y/N]: ").strip().lower()
+ except (EOFError, KeyboardInterrupt):
+ answer = ""
+
+ if answer not in ("y", "yes"):
+ if any(check.critical for check in failing):
+ raise SystemExit(1)
+ return
+
+ for check in failing:
+ try:
+ check.fix()
+ except subprocess.CalledProcessError as error:
+ if check.critical:
+ print(f"Critical fix failed rc={error.returncode}")
+ print(f"stdout: {error.stdout}")
+ print(f"stderr: {error.stderr}")
+ raise
+ print(f"Optional improvement failed: rc={error.returncode}")
+ print(f"stdout: {error.stdout}")
+ print(f"stderr: {error.stderr}")
+
+ print("System configuration completed.")
+
+
+# ------------------------------ specific checks: multicast ------------------------------
+
+
+class MulticastConfiguratorLinux(SystemConfigurator):
+ critical = True
+ MULTICAST_PREFIX = "224.0.0.0/4"
+
+ def __init__(self, loopback_interface: str = "lo"):
+ self.loopback_interface = loopback_interface
+
+ self.loopback_ok: bool | None = None
+ self.route_ok: bool | None = None
+
+ self.enable_multicast_cmd = [
+ "ip",
+ "link",
+ "set",
+ self.loopback_interface,
+ "multicast",
+ "on",
+ ]
+ self.add_route_cmd = [
+ "ip",
+ "route",
+ "add",
+ self.MULTICAST_PREFIX,
+ "dev",
+ self.loopback_interface,
+ ]
+
+ def check(self) -> bool:
+ # Verify `ip` exists (iproute2)
+ try:
+ subprocess.run(["ip", "-V"], capture_output=True, text=True, check=False)
+ except FileNotFoundError as error:
+ print(
+ f"ERROR: `ip` not found (iproute2 missing, did you install system requirements?): {error}"
+ )
+ self.loopback_ok = self.route_ok = False
+ return False
+ except Exception as error:
+ print(f"ERROR: failed probing `ip`: {error}")
+ self.loopback_ok = self.route_ok = False
+ return False
+
+ # Check that the MULTICAST flag is set on the loopback interface
+ try:
+ result = subprocess.run(
+ ["ip", "-o", "link", "show", self.loopback_interface],
+ capture_output=True,
+ text=True,
+ )
+ if result.returncode != 0:
+ print(
+ f"ERROR: `ip link show {self.loopback_interface}` rc={result.returncode} "
+ f"stderr={result.stderr!r}"
+ )
+ self.loopback_ok = False
+ else:
+ match = re.search(r"<([^>]*)>", result.stdout)
+ flags = {
+ flag.strip().upper()
+ for flag in (match.group(1).split(",") if match else [])
+ if flag.strip()
+ }
+ self.loopback_ok = "MULTICAST" in flags
+ except Exception as error:
+ print(f"ERROR: failed checking loopback multicast: {error}")
+ self.loopback_ok = False
+
+ # Check if multicast route exists
+ try:
+ result = subprocess.run(
+ ["ip", "-o", "route", "show", self.MULTICAST_PREFIX],
+ capture_output=True,
+ text=True,
+ )
+ if result.returncode != 0:
+ print(
+ f"ERROR: `ip route show {self.MULTICAST_PREFIX}` rc={result.returncode} "
+ f"stderr={result.stderr!r}"
+ )
+ self.route_ok = False
+ else:
+ self.route_ok = bool(result.stdout.strip())
+ except Exception as error:
+ print(f"ERROR: failed checking multicast route: {error}")
+ self.route_ok = False
+
+ return bool(self.loopback_ok and self.route_ok)
+
+ def explanation(self) -> str | None:
+ output = ""
+ if not self.loopback_ok:
+ output += f"- Multicast: sudo {' '.join(self.enable_multicast_cmd)}\n"
+ if not self.route_ok:
+ output += f"- Multicast: sudo {' '.join(self.add_route_cmd)}\n"
+ return output
+
+ def fix(self) -> None:
+ if not self.loopback_ok:
+ sudo_run(*self.enable_multicast_cmd, check=True, text=True, capture_output=True)
+ if not self.route_ok:
+ sudo_run(*self.add_route_cmd, check=True, text=True, capture_output=True)
+
+
+class MulticastConfiguratorMacOS(SystemConfigurator):
+ critical = True
+
+ def __init__(self, loopback_interface: str = "lo0"):
+ self.loopback_interface = loopback_interface
+ self.add_route_cmd = [
+ "route",
+ "add",
+ "-net",
+ "224.0.0.0/4",
+ "-interface",
+ self.loopback_interface,
+ ]
+
+ def check(self) -> bool:
+ # `netstat -nr` shows the routing table. We search for a 224/4 route entry.
+ try:
+ result = subprocess.run(["netstat", "-nr"], capture_output=True, text=True)
+ if result.returncode != 0:
+ print(f"ERROR: `netstat -nr` rc={result.returncode} stderr={result.stderr!r}")
+ return False
+
+ route_ok = ("224.0.0.0/4" in result.stdout) or ("224.0.0/4" in result.stdout)
+ return bool(route_ok)
+ except Exception as error:
+ print(f"ERROR: failed checking multicast route via netstat: {error}")
+ return False
+
+ def explanation(self) -> str | None:
+ return f"Multicast: - sudo {' '.join(self.add_route_cmd)}"
+
+ def fix(self) -> None:
+ sudo_run(*self.add_route_cmd, check=True, text=True, capture_output=True)
+
+
+# ------------------------------ specific checks: buffers ------------------------------
+
+IDEAL_RMEM_SIZE = 67_108_864 # 64MB
+
+
+class BufferConfiguratorLinux(SystemConfigurator):
+ critical = False
+
+ TARGET_RMEM_SIZE = IDEAL_RMEM_SIZE
+
+ def __init__(self) -> None:
+ self.needs: list[tuple[str, int]] = [] # (key, target_value)
+
+ def check(self) -> bool:
+ self.needs.clear()
+ for key, target in [
+ ("net.core.rmem_max", self.TARGET_RMEM_SIZE),
+ ("net.core.rmem_default", self.TARGET_RMEM_SIZE),
+ ]:
+ current = _read_sysctl_int(key)
+ if current is None or current < target:
+ self.needs.append((key, target))
+ return not self.needs
+
+ def explanation(self) -> str | None:
+ lines = []
+ for key, target in self.needs:
+ lines.append(f"- socket buffer optimization: sudo sysctl -w {key}={target}")
+ return "\n".join(lines)
+
+ def fix(self) -> None:
+ for key, target in self.needs:
+ _write_sysctl_int(key, target)
+
+
+class BufferConfiguratorMacOS(SystemConfigurator):
+ critical = False
+ MAX_POSSIBLE_RECVSPACE = 2_097_152
+ MAX_POSSIBLE_BUFFER_SIZE = 8_388_608
+ MAX_POSSIBLE_DGRAM_SIZE = 65_535
+ # These maximums were observed on macOS 26.
+
+ TARGET_BUFFER_SIZE = MAX_POSSIBLE_BUFFER_SIZE
+ TARGET_RECVSPACE = MAX_POSSIBLE_RECVSPACE # ideally IDEAL_RMEM_SIZE, but macOS 26 (and likely macOS in general) does not accept values that large
+ TARGET_DGRAM_SIZE = MAX_POSSIBLE_DGRAM_SIZE
+
+ def __init__(self) -> None:
+ self.needs: list[tuple[str, int]] = []
+
+ def check(self) -> bool:
+ self.needs.clear()
+ for key, target in [
+ ("kern.ipc.maxsockbuf", self.TARGET_BUFFER_SIZE),
+ ("net.inet.udp.recvspace", self.TARGET_RECVSPACE),
+ ("net.inet.udp.maxdgram", self.TARGET_DGRAM_SIZE),
+ ]:
+ current = _read_sysctl_int(key)
+ if current is None or current < target:
+ self.needs.append((key, target))
+ return not self.needs
+
+ def explanation(self) -> str | None:
+ lines = []
+ for key, target in self.needs:
+ lines.append(f"- sudo sysctl -w {key}={target}")
+ return "\n".join(lines)
+
+ def fix(self) -> None:
+ for key, target in self.needs:
+ _write_sysctl_int(key, target)
+
+
+# ------------------------------ specific checks: ulimit ------------------------------
+
+
+class MaxFileConfiguratorMacOS(SystemConfigurator):
+ """Ensure the open file descriptor limit (ulimit -n) is at least TARGET_FILE_COUNT_LIMIT."""
+
+ critical = False
+ TARGET_FILE_COUNT_LIMIT = 65536
+
+ def __init__(self, target: int = TARGET_FILE_COUNT_LIMIT):
+ self.target = target
+ self.current_soft: int = 0
+ self.current_hard: int = 0
+ self.can_fix_without_sudo: bool = False
+
+ def check(self) -> bool:
+ try:
+ self.current_soft, self.current_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+ except Exception as error:
+ print(f"[ulimit] ERROR: failed to get RLIMIT_NOFILE: {error}")
+ return False
+
+ if self.current_soft >= self.target:
+ return True
+
+ # Check if we can raise to target without sudo (hard limit is high enough)
+ self.can_fix_without_sudo = self.current_hard >= self.target
+ return False
+
+ def explanation(self) -> str | None:
+ lines = []
+ if self.can_fix_without_sudo:
+ lines.append(f"- Raise soft file count limit to {self.target} (no sudo required)")
+ else:
+ lines.append(f"- Raise soft file count limit to {min(self.target, self.current_hard)}")
+ lines.append(
+ f"- Raise hard limit via: sudo launchctl limit maxfiles {self.target} {self.target}"
+ )
+ return "\n".join(lines)
+
+ def fix(self) -> None:
+ if self.current_soft >= self.target:
+ return
+
+ if self.can_fix_without_sudo:
+ # Hard limit is sufficient, just raise the soft limit
+ try:
+ resource.setrlimit(resource.RLIMIT_NOFILE, (self.target, self.current_hard))
+ except Exception as error:
+ print(f"[ulimit] ERROR: failed to set soft limit: {error}")
+ raise
+ else:
+ # Need to raise both soft and hard limits via launchctl
+ try:
+ sudo_run(
+ "launchctl",
+ "limit",
+ "maxfiles",
+ str(self.target),
+ str(self.target),
+ check=True,
+ text=True,
+ capture_output=True,
+ )
+ except subprocess.CalledProcessError as error:
+ print(f"[ulimit] WARNING: launchctl failed: {error.stderr}")
+ # Fallback: raise soft limit as high as the current hard limit allows
+ if self.current_hard > self.current_soft:
+ try:
+ resource.setrlimit(
+ resource.RLIMIT_NOFILE, (self.current_hard, self.current_hard)
+ )
+ except Exception as fallback_error:
+ print(f"[ulimit] ERROR: fallback also failed: {fallback_error}")
+ raise
+
+ # After launchctl, try to apply the new limit to the current process
+ try:
+ resource.setrlimit(resource.RLIMIT_NOFILE, (self.target, self.target))
+ except Exception as error:
+ print(
+ f"[ulimit] WARNING: could not apply to current process (restart may be required): {error}"
+ )
diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py
index faf50a945e..a85462cf31 100644
--- a/dimos/protocol/service/test_lcmservice.py
+++ b/dimos/protocol/service/test_lcmservice.py
@@ -13,555 +13,306 @@
# limitations under the License.
import os
-import subprocess
-from unittest.mock import patch
+import pickle
+import threading
+import time
+from unittest.mock import MagicMock, patch
import pytest
from dimos.protocol.service.lcmservice import (
- TARGET_MAX_DGRAM_SIZE_MACOS,
- TARGET_MAX_SOCKET_BUFFER_SIZE_MACOS,
- TARGET_RMEM_SIZE,
+ _DEFAULT_LCM_URL,
+ LCMConfig,
+ LCMService,
+ Topic,
autoconf,
- check_buffers,
- check_multicast,
- check_root,
+)
+from dimos.protocol.service.system_configurator import (
+ BufferConfiguratorLinux,
+ BufferConfiguratorMacOS,
+ MaxFileConfiguratorMacOS,
+ MulticastConfiguratorLinux,
+ MulticastConfiguratorMacOS,
)
+# ----------------------------- autoconf tests -----------------------------
-def get_sudo_prefix() -> str:
- """Return 'sudo ' if not running as root, empty string if running as root."""
- return "" if check_root() else "sudo "
-
-
-def test_check_multicast_all_configured() -> None:
- """Test check_multicast when system is properly configured."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock successful checks with realistic output format
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {
- "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00",
- "returncode": 0,
- },
- )(),
- type(
- "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0}
- )(),
- ]
-
- result = check_multicast()
- assert result == []
-
-
-def test_check_multicast_missing_multicast_flag() -> None:
- """Test check_multicast when loopback interface lacks multicast."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock interface without MULTICAST flag (realistic current system state)
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {
- "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00",
- "returncode": 0,
- },
- )(),
- type(
- "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0}
- )(),
- ]
-
- result = check_multicast()
- sudo = get_sudo_prefix()
- assert result == [f"{sudo}ifconfig lo multicast"]
-
-
-def test_check_multicast_missing_route() -> None:
- """Test check_multicast when multicast route is missing."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock missing route - interface has multicast but no route
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {
- "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00",
- "returncode": 0,
- },
- )(),
- type(
- "MockResult", (), {"stdout": "", "returncode": 0}
- )(), # Empty output - no route
- ]
-
- result = check_multicast()
- sudo = get_sudo_prefix()
- assert result == [f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"]
-
-
-def test_check_multicast_all_missing() -> None:
- """Test check_multicast when both multicast flag and route are missing (current system state)."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock both missing - matches actual current system state
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {
- "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00",
- "returncode": 0,
- },
- )(),
- type(
- "MockResult", (), {"stdout": "", "returncode": 0}
- )(), # Empty output - no route
- ]
-
- result = check_multicast()
- sudo = get_sudo_prefix()
- expected = [
- f"{sudo}ifconfig lo multicast",
- f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo",
- ]
- assert result == expected
-
-
-def test_check_multicast_subprocess_exception() -> None:
- """Test check_multicast when subprocess calls fail."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock subprocess exceptions
- mock_run.side_effect = Exception("Command failed")
-
- result = check_multicast()
- sudo = get_sudo_prefix()
- expected = [
- f"{sudo}ifconfig lo multicast",
- f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo",
- ]
- assert result == expected
-
-
-def test_check_multicast_macos() -> None:
- """Test check_multicast on macOS when configuration is needed."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Darwin"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock netstat -nr to not contain the multicast route
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {
- "stdout": "default 192.168.1.1 UGScg en0",
- "returncode": 0,
- },
- )(),
- ]
-
- result = check_multicast()
- sudo = get_sudo_prefix()
- expected = [f"{sudo}route add -net 224.0.0.0/4 -interface lo0"]
- assert result == expected
-
-
-def test_check_buffers_all_configured() -> None:
- """Test check_buffers when system is properly configured."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock sufficient buffer sizes
- mock_run.side_effect = [
- type(
- "MockResult", (), {"stdout": "net.core.rmem_max = 67108864", "returncode": 0}
- )(),
- type(
- "MockResult",
- (),
- {"stdout": "net.core.rmem_default = 16777216", "returncode": 0},
- )(),
- ]
-
- commands, buffer_size = check_buffers()
- assert commands == []
- assert buffer_size >= TARGET_RMEM_SIZE
-
-
-def test_check_buffers_low_max_buffer() -> None:
- """Test check_buffers when rmem_max is too low."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock low rmem_max
- mock_run.side_effect = [
- type(
- "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0}
- )(),
- type(
- "MockResult",
- (),
- {"stdout": f"net.core.rmem_default = {TARGET_RMEM_SIZE}", "returncode": 0},
- )(),
- ]
-
- commands, buffer_size = check_buffers()
- sudo = get_sudo_prefix()
- assert commands == [f"{sudo}sysctl -w net.core.rmem_max={TARGET_RMEM_SIZE}"]
- assert buffer_size == 1048576
-
-
-def test_check_buffers_low_default_buffer() -> None:
- """Test check_buffers when rmem_default is too low."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock low rmem_default
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {"stdout": f"net.core.rmem_max = {TARGET_RMEM_SIZE}", "returncode": 0},
- )(),
- type(
- "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0}
- )(),
- ]
-
- commands, buffer_size = check_buffers()
- sudo = get_sudo_prefix()
- assert commands == [f"{sudo}sysctl -w net.core.rmem_default={TARGET_RMEM_SIZE}"]
- assert buffer_size == TARGET_RMEM_SIZE
-
-
-def test_check_buffers_both_low() -> None:
- """Test check_buffers when both buffer sizes are too low."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock both low
- mock_run.side_effect = [
- type(
- "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0}
- )(),
- type(
- "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0}
- )(),
- ]
-
- commands, buffer_size = check_buffers()
- sudo = get_sudo_prefix()
- expected = [
- f"{sudo}sysctl -w net.core.rmem_max={TARGET_RMEM_SIZE}",
- f"{sudo}sysctl -w net.core.rmem_default={TARGET_RMEM_SIZE}",
- ]
- assert commands == expected
- assert buffer_size == 1048576
-
-
-def test_check_buffers_subprocess_exception() -> None:
- """Test check_buffers when subprocess calls fail."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock subprocess exceptions
- mock_run.side_effect = Exception("Command failed")
-
- commands, buffer_size = check_buffers()
- sudo = get_sudo_prefix()
- expected = [
- f"{sudo}sysctl -w net.core.rmem_max={TARGET_RMEM_SIZE}",
- f"{sudo}sysctl -w net.core.rmem_default={TARGET_RMEM_SIZE}",
- ]
- assert commands == expected
- assert buffer_size is None
-
-
-def test_check_buffers_parsing_error() -> None:
- """Test check_buffers when output parsing fails."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock malformed output
- mock_run.side_effect = [
- type("MockResult", (), {"stdout": "invalid output", "returncode": 0})(),
- type("MockResult", (), {"stdout": "also invalid", "returncode": 0})(),
- ]
-
- commands, buffer_size = check_buffers()
- sudo = get_sudo_prefix()
- expected = [
- f"{sudo}sysctl -w net.core.rmem_max={TARGET_RMEM_SIZE}",
- f"{sudo}sysctl -w net.core.rmem_default={TARGET_RMEM_SIZE}",
- ]
- assert commands == expected
- assert buffer_size is None
-
-
-def test_check_buffers_dev_container() -> None:
- """Test check_buffers in dev container where sysctl fails."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock dev container behavior - sysctl returns non-zero
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {
- "stdout": "sysctl: cannot stat /proc/sys/net/core/rmem_max: No such file or directory",
- "returncode": 255,
- },
- )(),
- type(
- "MockResult",
- (),
- {
- "stdout": "sysctl: cannot stat /proc/sys/net/core/rmem_default: No such file or directory",
- "returncode": 255,
- },
- )(),
- ]
-
- commands, buffer_size = check_buffers()
- sudo = get_sudo_prefix()
- expected = [
- f"{sudo}sysctl -w net.core.rmem_max={TARGET_RMEM_SIZE}",
- f"{sudo}sysctl -w net.core.rmem_default={TARGET_RMEM_SIZE}",
- ]
- assert commands == expected
- assert buffer_size is None
-
-
-def test_check_buffers_macos_all_configured() -> None:
- """Test check_buffers on macOS when system is properly configured."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Darwin"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock sufficient buffer sizes for macOS
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {
- "stdout": f"kern.ipc.maxsockbuf: {TARGET_MAX_SOCKET_BUFFER_SIZE_MACOS}",
- "returncode": 0,
- },
- )(),
- type(
- "MockResult",
- (),
- {"stdout": f"net.inet.udp.recvspace: {TARGET_RMEM_SIZE}", "returncode": 0},
- )(),
- type(
- "MockResult",
- (),
- {
- "stdout": f"net.inet.udp.maxdgram: {TARGET_MAX_DGRAM_SIZE_MACOS}",
- "returncode": 0,
- },
- )(),
- ]
-
- commands, buffer_size = check_buffers()
- assert commands == []
- assert buffer_size == TARGET_MAX_SOCKET_BUFFER_SIZE_MACOS
-
-
-def test_check_buffers_macos_needs_config() -> None:
- """Test check_buffers on macOS when configuration is needed."""
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Darwin"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- mock_max_sock_buf_size = 4194304
- # Mock low buffer sizes for macOS
- mock_run.side_effect = [
- type(
- "MockResult",
- (),
- {"stdout": f"kern.ipc.maxsockbuf: {mock_max_sock_buf_size}", "returncode": 0},
- )(),
- type(
- "MockResult", (), {"stdout": "net.inet.udp.recvspace: 1048576", "returncode": 0}
- )(),
- type(
- "MockResult", (), {"stdout": "net.inet.udp.maxdgram: 32768", "returncode": 0}
- )(),
- ]
-
- commands, buffer_size = check_buffers()
- sudo = get_sudo_prefix()
- expected = [
- f"{sudo}sysctl -w kern.ipc.maxsockbuf={TARGET_MAX_SOCKET_BUFFER_SIZE_MACOS}",
- f"{sudo}sysctl -w net.inet.udp.recvspace={TARGET_RMEM_SIZE}",
- f"{sudo}sysctl -w net.inet.udp.maxdgram={TARGET_MAX_DGRAM_SIZE_MACOS}",
- ]
- assert commands == expected
- assert buffer_size == mock_max_sock_buf_size
-
-
-def test_autoconf_no_config_needed() -> None:
- """Test autoconf when no configuration is needed."""
- # Clear CI environment variable for this test
- with patch.dict(os.environ, {"CI": ""}, clear=False):
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock all checks passing
- mock_run.side_effect = [
- # check_multicast calls
- type(
- "MockResult",
- (),
- {
- "stdout": "1: lo: mtu 65536",
- "returncode": 0,
- },
- )(),
- type(
- "MockResult",
- (),
- {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0},
- )(),
- # check_buffers calls
- type(
- "MockResult",
- (),
- {"stdout": f"net.core.rmem_max = {TARGET_RMEM_SIZE}", "returncode": 0},
- )(),
- type(
- "MockResult",
- (),
- {"stdout": f"net.core.rmem_default = {TARGET_RMEM_SIZE}", "returncode": 0},
- )(),
- ]
+class TestConfigureSystemForLcm:
+ def test_creates_linux_checks_on_linux(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
+ with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure:
+ autoconf()
+ mock_configure.assert_called_once()
+ checks = mock_configure.call_args[0][0]
+ assert len(checks) == 2
+ assert isinstance(checks[0], MulticastConfiguratorLinux)
+ assert isinstance(checks[1], BufferConfiguratorLinux)
+ assert checks[0].loopback_interface == "lo"
+
+ def test_creates_macos_checks_on_darwin(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Darwin"):
+ with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure:
+ autoconf()
+ mock_configure.assert_called_once()
+ checks = mock_configure.call_args[0][0]
+ assert len(checks) == 3
+ assert isinstance(checks[0], MulticastConfiguratorMacOS)
+ assert isinstance(checks[1], BufferConfiguratorMacOS)
+ assert isinstance(checks[2], MaxFileConfiguratorMacOS)
+ assert checks[0].loopback_interface == "lo0"
+
+ def test_passes_check_only_flag(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
+ with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure:
+ autoconf(check_only=True)
+ mock_configure.assert_called_once()
+ assert mock_configure.call_args[1]["check_only"] is True
+
+ def test_logs_error_on_unsupported_system(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Windows"):
+ with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure:
with patch("dimos.protocol.service.lcmservice.logger") as mock_logger:
autoconf()
- # Should not log anything when no config is needed
- mock_logger.info.assert_not_called()
- mock_logger.error.assert_not_called()
- mock_logger.warning.assert_not_called()
+ mock_configure.assert_not_called()
+ mock_logger.error.assert_called_once()
+ assert "Windows" in mock_logger.error.call_args[0][0]
-def test_autoconf_with_config_needed_success() -> None:
- """Test autoconf when configuration is needed and commands succeed."""
- # Clear CI environment variable for this test
- with patch.dict(os.environ, {"CI": ""}, clear=False):
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock checks failing, then mock the execution succeeding
- mock_run.side_effect = [
- # check_multicast calls
- type(
- "MockResult",
- (),
- {"stdout": "1: lo: mtu 65536", "returncode": 0},
- )(),
- type("MockResult", (), {"stdout": "", "returncode": 0})(),
- # check_buffers calls
- type(
- "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0}
- )(),
- type(
- "MockResult",
- (),
- {"stdout": "net.core.rmem_default = 1048576", "returncode": 0},
- )(),
- # Command execution calls
- type(
- "MockResult", (), {"stdout": "success", "returncode": 0}
- )(), # ifconfig lo multicast
- type(
- "MockResult", (), {"stdout": "success", "returncode": 0}
- )(), # route add...
- type(
- "MockResult", (), {"stdout": "success", "returncode": 0}
- )(), # sysctl rmem_max
- type(
- "MockResult", (), {"stdout": "success", "returncode": 0}
- )(), # sysctl rmem_default
- ]
-
- from unittest.mock import call
+# ----------------------------- LCMConfig tests -----------------------------
- with patch("dimos.protocol.service.lcmservice.logger") as mock_logger:
- autoconf()
- sudo = get_sudo_prefix()
- # Verify the expected log calls
- expected_info_calls = [
- call("System configuration required. Executing commands..."),
- call(f" Running: {sudo}ifconfig lo multicast"),
- call(" ✓ Success"),
- call(f" Running: {sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"),
- call(" ✓ Success"),
- call(f" Running: {sudo}sysctl -w net.core.rmem_max={TARGET_RMEM_SIZE}"),
- call(" ✓ Success"),
- call(
- f" Running: {sudo}sysctl -w net.core.rmem_default={TARGET_RMEM_SIZE}"
- ),
- call(" ✓ Success"),
- call("System configuration completed."),
- ]
-
- mock_logger.info.assert_has_calls(expected_info_calls)
-
-
-def test_autoconf_with_command_failures() -> None:
- """Test autoconf when some commands fail."""
- # Clear CI environment variable for this test
- with patch.dict(os.environ, {"CI": ""}, clear=False):
- with patch("dimos.protocol.service.lcmservice.platform.system", return_value="Linux"):
- with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run:
- # Mock checks failing, then mock some commands failing
- mock_run.side_effect = [
- # check_multicast calls
- type(
- "MockResult",
- (),
- {"stdout": "1: lo: mtu 65536", "returncode": 0},
- )(),
- type("MockResult", (), {"stdout": "", "returncode": 0})(),
- # check_buffers calls (no buffer issues for simpler test)
- type(
- "MockResult",
- (),
- {"stdout": f"net.core.rmem_max = {TARGET_RMEM_SIZE}", "returncode": 0},
- )(),
- type(
- "MockResult",
- (),
- {"stdout": f"net.core.rmem_default = {TARGET_RMEM_SIZE}", "returncode": 0},
- )(),
- # Command execution calls - first succeeds, second fails
- type(
- "MockResult", (), {"stdout": "success", "returncode": 0}
- )(), # ifconfig lo multicast
- subprocess.CalledProcessError(
- 1,
- [
- *get_sudo_prefix().split(),
- "route",
- "add",
- "-net",
- "224.0.0.0",
- "netmask",
- "240.0.0.0",
- "dev",
- "lo",
- ],
- "Permission denied",
- "Operation not permitted",
- ),
- ]
+class TestLCMConfig:
+ def test_default_values(self) -> None:
+ config = LCMConfig()
+ assert config.ttl == 0
+ assert config.url == _DEFAULT_LCM_URL
+ assert config.autoconf is True
+ assert config.lcm is None
- with patch("dimos.protocol.service.lcmservice.logger") as mock_logger:
- # The function should raise on multicast/route failures
- with pytest.raises(subprocess.CalledProcessError):
- autoconf()
-
- # Verify it logged the failure before raising
- info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
- error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
-
- assert "System configuration required. Executing commands..." in info_calls
- assert " ✓ Success" in info_calls # First command succeeded
- assert any(
- "✗ Failed to configure multicast" in call for call in error_calls
- ) # Second command failed
+ def test_custom_url(self) -> None:
+ custom_url = "udpm://192.168.1.1:7777?ttl=1"
+ config = LCMConfig(url=custom_url)
+ assert config.url == custom_url
+
+ def test_post_init_sets_default_url_when_none(self) -> None:
+ config = LCMConfig(url=None)
+ assert config.url == _DEFAULT_LCM_URL
+
+ def test_autoconf_can_be_disabled(self) -> None:
+ config = LCMConfig(autoconf=False)
+ assert config.autoconf is False
+
+
+# ----------------------------- Topic tests -----------------------------
+
+
+class TestTopic:
+ def test_str_without_lcm_type(self) -> None:
+ topic = Topic(topic="my_topic")
+ assert str(topic) == "my_topic"
+
+ def test_str_with_lcm_type(self) -> None:
+ mock_type = MagicMock()
+ mock_type.msg_name = "TestMessage"
+ topic = Topic(topic="my_topic", lcm_type=mock_type)
+ assert str(topic) == "my_topic#TestMessage"
+
+
+# ----------------------------- LCMService tests -----------------------------
+
+
+class TestLCMService:
+ def test_init_with_default_config(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ service = LCMService()
+ assert service.config.url == _DEFAULT_LCM_URL
+ assert service.l == mock_lcm_instance
+ mock_lcm_class.assert_called_once_with(_DEFAULT_LCM_URL)
+
+ def test_init_with_custom_url(self) -> None:
+ custom_url = "udpm://192.168.1.1:7777?ttl=1"
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ # Pass url as kwarg, not config=
+ LCMService(url=custom_url)
+ mock_lcm_class.assert_called_once_with(custom_url)
+
+ def test_init_with_existing_lcm_instance(self) -> None:
+ mock_lcm_instance = MagicMock()
+
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ # Pass lcm as kwarg
+ service = LCMService(lcm=mock_lcm_instance)
+ mock_lcm_class.assert_not_called()
+ assert service.l == mock_lcm_instance
+
+ def test_start_and_stop(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ with patch("dimos.protocol.service.lcmservice.autoconf"):
+ service = LCMService(autoconf=False)
+ service.start()
+
+ # Verify thread is running
+ assert service._thread is not None
+ assert service._thread.is_alive()
+
+ service.stop()
+
+ # Give the thread a moment to stop
+ time.sleep(0.1)
+ assert not service._thread.is_alive()
+
+ def test_start_calls_configure_system(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ with patch("dimos.protocol.service.lcmservice.autoconf") as mock_configure:
+ service = LCMService(autoconf=True)
+ service.start()
+
+ # With autoconf=True, check_only should be False
+ mock_configure.assert_called_once_with(check_only=False)
+
+ service.stop()
+
+ def test_start_with_autoconf_disabled(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ with patch("dimos.protocol.service.lcmservice.autoconf") as mock_configure:
+ service = LCMService(autoconf=False)
+ service.start()
+
+ # With autoconf=False, check_only should be True
+ mock_configure.assert_called_once_with(check_only=True)
+
+ service.stop()
+
+ def test_getstate_excludes_unpicklable_attrs(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ service = LCMService()
+ state = service.__getstate__()
+
+ assert "l" not in state
+ assert "_stop_event" not in state
+ assert "_thread" not in state
+ assert "_l_lock" not in state
+ assert "_call_thread_pool" not in state
+ assert "_call_thread_pool_lock" not in state
+
+ def test_setstate_reinitializes_runtime_attrs(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ service = LCMService()
+ state = service.__getstate__()
+
+ # Simulate unpickling
+ new_service = object.__new__(LCMService)
+ new_service.__setstate__(state)
+
+ assert new_service.l is None
+ assert isinstance(new_service._stop_event, threading.Event)
+ assert new_service._thread is None
+ # threading.Lock is a factory function, not a type
+ # Just check that the lock exists and has acquire/release methods
+ assert hasattr(new_service._l_lock, "acquire")
+ assert hasattr(new_service._l_lock, "release")
+
+ def test_start_reinitializes_lcm_after_unpickling(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ with patch("dimos.protocol.service.lcmservice.autoconf"):
+ service = LCMService()
+ state = service.__getstate__()
+
+ # Simulate unpickling
+ new_service = object.__new__(LCMService)
+ new_service.__setstate__(state)
+
+ # Start should reinitialize LCM
+ new_service.start()
+
+ # LCM should be created again
+ assert mock_lcm_class.call_count == 2
+
+ new_service.stop()
+
+ def test_stop_cleans_up_lcm_instance(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ with patch("dimos.protocol.service.lcmservice.autoconf"):
+ service = LCMService()
+ service.start()
+ service.stop()
+
+ # LCM instance should be cleaned up when we created it
+ assert service.l is None
+
+ def test_stop_preserves_external_lcm_instance(self) -> None:
+ mock_lcm_instance = MagicMock()
+
+ with patch("dimos.protocol.service.lcmservice.autoconf"):
+ # Pass lcm as kwarg
+ service = LCMService(lcm=mock_lcm_instance)
+ service.start()
+ service.stop()
+
+ # External LCM instance should not be cleaned up
+ assert service.l == mock_lcm_instance
+
+ def test_get_call_thread_pool_creates_pool(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ service = LCMService()
+ assert service._call_thread_pool is None
+
+ pool = service._get_call_thread_pool()
+ assert pool is not None
+ assert service._call_thread_pool == pool
+
+ # Should return same pool on subsequent calls
+ pool2 = service._get_call_thread_pool()
+ assert pool2 is pool
+
+ # Clean up
+ pool.shutdown(wait=False)
+
+ def test_stop_shuts_down_thread_pool(self) -> None:
+ with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class:
+ mock_lcm_instance = MagicMock()
+ mock_lcm_class.return_value = mock_lcm_instance
+
+ with patch("dimos.protocol.service.lcmservice.autoconf"):
+ service = LCMService()
+ service.start()
+
+ # Create thread pool
+ pool = service._get_call_thread_pool()
+ assert pool is not None
+
+ service.stop()
+
+ # Pool should be cleaned up
+ assert service._call_thread_pool is None
diff --git a/dimos/protocol/service/test_system_configurator.py b/dimos/protocol/service/test_system_configurator.py
new file mode 100644
index 0000000000..22bb662044
--- /dev/null
+++ b/dimos/protocol/service/test_system_configurator.py
@@ -0,0 +1,483 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import resource
+import subprocess
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from dimos.protocol.service.system_configurator import (
+ IDEAL_RMEM_SIZE,
+ BufferConfiguratorLinux,
+ BufferConfiguratorMacOS,
+ MaxFileConfiguratorMacOS,
+ MulticastConfiguratorLinux,
+ MulticastConfiguratorMacOS,
+ SystemConfigurator,
+ _is_root_user,
+ _read_sysctl_int,
+ _write_sysctl_int,
+ configure_system,
+ sudo_run,
+)
+
+# ----------------------------- Helper function tests -----------------------------
+
+
+class TestIsRootUser:
+ def test_is_root_when_euid_is_zero(self) -> None:
+ # Clear the cache before testing
+ _is_root_user.cache_clear()
+ with patch("os.geteuid", return_value=0):
+ assert _is_root_user() is True
+
+ def test_is_not_root_when_euid_is_nonzero(self) -> None:
+ _is_root_user.cache_clear()
+ with patch("os.geteuid", return_value=1000):
+ assert _is_root_user() is False
+
+ def test_returns_false_when_geteuid_not_available(self) -> None:
+ _is_root_user.cache_clear()
+ with patch("os.geteuid", side_effect=AttributeError):
+ assert _is_root_user() is False
+
+
+class TestSudoRun:
+ def test_runs_without_sudo_when_root(self) -> None:
+ _is_root_user.cache_clear()
+ with patch("os.geteuid", return_value=0):
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0)
+ sudo_run("echo", "hello", check=True)
+ mock_run.assert_called_once_with(["echo", "hello"], check=True)
+
+ def test_runs_with_sudo_when_not_root(self) -> None:
+ _is_root_user.cache_clear()
+ with patch("os.geteuid", return_value=1000):
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0)
+ sudo_run("echo", "hello", check=True)
+ mock_run.assert_called_once_with(["sudo", "echo", "hello"], check=True)
+
+
+class TestReadSysctlInt:
+ def test_reads_value_with_equals_sign(self) -> None:
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0, stdout="net.core.rmem_max = 67108864")
+ result = _read_sysctl_int("net.core.rmem_max")
+ assert result == 67108864
+
+ def test_reads_value_with_colon(self) -> None:
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0, stdout="kern.ipc.maxsockbuf: 8388608")
+ result = _read_sysctl_int("kern.ipc.maxsockbuf")
+ assert result == 8388608
+
+ def test_returns_none_on_nonzero_returncode(self) -> None:
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=1, stderr="error")
+ result = _read_sysctl_int("net.core.rmem_max")
+ assert result is None
+
+ def test_returns_none_on_malformed_output(self) -> None:
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0, stdout="invalid output")
+ result = _read_sysctl_int("net.core.rmem_max")
+ assert result is None
+
+ def test_returns_none_on_exception(self) -> None:
+ with patch("subprocess.run", side_effect=Exception("Command failed")):
+ result = _read_sysctl_int("net.core.rmem_max")
+ assert result is None
+
+
+class TestWriteSysctlInt:
+ def test_calls_sudo_run_with_correct_args(self) -> None:
+ _is_root_user.cache_clear()
+ with patch("os.geteuid", return_value=1000):
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0)
+ _write_sysctl_int("net.core.rmem_max", 67108864)
+ mock_run.assert_called_once_with(
+ ["sudo", "sysctl", "-w", "net.core.rmem_max=67108864"],
+ check=True,
+ text=True,
+ capture_output=False,
+ )
+
+
+# ----------------------------- configure_system tests -----------------------------
+
+
+class MockConfigurator(SystemConfigurator):
+ """A mock configurator for testing configure_system."""
+
+ def __init__(self, passes: bool = True, is_critical: bool = False) -> None:
+ self._passes = passes
+ self.critical = is_critical
+ self.fix_called = False
+
+ def check(self) -> bool:
+ return self._passes
+
+ def explanation(self) -> str | None:
+ if self._passes:
+ return None
+ return "Mock explanation"
+
+ def fix(self) -> None:
+ self.fix_called = True
+
+
+class TestConfigureSystem:
+ def test_skips_in_ci_environment(self) -> None:
+ with patch.dict(os.environ, {"CI": "true"}):
+ mock_check = MockConfigurator(passes=False)
+ configure_system([mock_check])
+ assert not mock_check.fix_called
+
+ def test_does_nothing_when_all_checks_pass(self) -> None:
+ with patch.dict(os.environ, {"CI": ""}, clear=False):
+ mock_check = MockConfigurator(passes=True)
+ configure_system([mock_check])
+ assert not mock_check.fix_called
+
+ def test_check_only_mode_does_not_fix(self) -> None:
+ with patch.dict(os.environ, {"CI": ""}, clear=False):
+ mock_check = MockConfigurator(passes=False)
+ configure_system([mock_check], check_only=True)
+ assert not mock_check.fix_called
+
+ def test_prompts_user_and_fixes_on_yes(self) -> None:
+ with patch.dict(os.environ, {"CI": ""}, clear=False):
+ mock_check = MockConfigurator(passes=False)
+ with patch("builtins.input", return_value="y"):
+ configure_system([mock_check])
+ assert mock_check.fix_called
+
+ def test_does_not_fix_on_no(self) -> None:
+ with patch.dict(os.environ, {"CI": ""}, clear=False):
+ mock_check = MockConfigurator(passes=False)
+ with patch("builtins.input", return_value="n"):
+ configure_system([mock_check])
+ assert not mock_check.fix_called
+
+ def test_exits_on_no_with_critical_check(self) -> None:
+ with patch.dict(os.environ, {"CI": ""}, clear=False):
+ mock_check = MockConfigurator(passes=False, is_critical=True)
+ with patch("builtins.input", return_value="n"):
+ with pytest.raises(SystemExit) as exc_info:
+ configure_system([mock_check])
+ assert exc_info.value.code == 1
+
+ def test_handles_eof_error_on_input(self) -> None:
+ with patch.dict(os.environ, {"CI": ""}, clear=False):
+ mock_check = MockConfigurator(passes=False)
+ with patch("builtins.input", side_effect=EOFError):
+ configure_system([mock_check])
+ assert not mock_check.fix_called
+
+
+# ----------------------------- MulticastConfiguratorLinux tests -----------------------------
+
+
+class TestMulticastConfiguratorLinux:
+ def test_check_returns_true_when_fully_configured(self) -> None:
+ configurator = MulticastConfiguratorLinux()
+ with patch("subprocess.run") as mock_run:
+ mock_run.side_effect = [
+ MagicMock(returncode=0), # ip -V
+ MagicMock(
+ returncode=0,
+                    stdout="1: lo: <LOOPBACK,UP,LOWER_UP,MULTICAST> mtu 65536",
+ ),
+ MagicMock(returncode=0, stdout="224.0.0.0/4 dev lo scope link"),
+ ]
+ assert configurator.check() is True
+ assert configurator.loopback_ok is True
+ assert configurator.route_ok is True
+
+ def test_check_returns_false_when_multicast_flag_missing(self) -> None:
+ configurator = MulticastConfiguratorLinux()
+ with patch("subprocess.run") as mock_run:
+ mock_run.side_effect = [
+ MagicMock(returncode=0), # ip -V
+                MagicMock(returncode=0, stdout="1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536"),
+ MagicMock(returncode=0, stdout="224.0.0.0/4 dev lo scope link"),
+ ]
+ assert configurator.check() is False
+ assert configurator.loopback_ok is False
+ assert configurator.route_ok is True
+
+ def test_check_returns_false_when_route_missing(self) -> None:
+ configurator = MulticastConfiguratorLinux()
+ with patch("subprocess.run") as mock_run:
+ mock_run.side_effect = [
+ MagicMock(returncode=0), # ip -V
+ MagicMock(
+ returncode=0,
+                    stdout="1: lo: <LOOPBACK,UP,LOWER_UP,MULTICAST> mtu 65536",
+ ),
+ MagicMock(returncode=0, stdout=""), # Empty - no route
+ ]
+ assert configurator.check() is False
+ assert configurator.loopback_ok is True
+ assert configurator.route_ok is False
+
+ def test_check_returns_false_when_ip_not_found(self) -> None:
+ configurator = MulticastConfiguratorLinux()
+ with patch("subprocess.run", side_effect=FileNotFoundError):
+ assert configurator.check() is False
+ assert configurator.loopback_ok is False
+ assert configurator.route_ok is False
+
+ def test_explanation_includes_needed_commands(self) -> None:
+ configurator = MulticastConfiguratorLinux()
+ configurator.loopback_ok = False
+ configurator.route_ok = False
+ explanation = configurator.explanation()
+ assert "ip link set lo multicast on" in explanation
+ assert "ip route add 224.0.0.0/4 dev lo" in explanation
+
+ def test_fix_runs_needed_commands(self) -> None:
+ _is_root_user.cache_clear()
+ configurator = MulticastConfiguratorLinux()
+ configurator.loopback_ok = False
+ configurator.route_ok = False
+ with patch("os.geteuid", return_value=0):
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0)
+ configurator.fix()
+ assert mock_run.call_count == 2
+
+
+# ----------------------------- MulticastConfiguratorMacOS tests -----------------------------
+
+
+class TestMulticastConfiguratorMacOS:
+ def test_check_returns_true_when_route_exists(self) -> None:
+ configurator = MulticastConfiguratorMacOS()
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(
+ returncode=0,
+ stdout="224.0.0.0/4 link#1 UCS lo0",
+ )
+ assert configurator.check() is True
+
+ def test_check_returns_false_when_route_missing(self) -> None:
+ configurator = MulticastConfiguratorMacOS()
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(
+ returncode=0, stdout="default 192.168.1.1 UGScg en0"
+ )
+ assert configurator.check() is False
+
+ def test_check_returns_false_on_netstat_error(self) -> None:
+ configurator = MulticastConfiguratorMacOS()
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=1, stderr="error")
+ assert configurator.check() is False
+
+ def test_explanation_includes_route_command(self) -> None:
+ configurator = MulticastConfiguratorMacOS()
+ explanation = configurator.explanation()
+ assert "route add -net 224.0.0.0/4 -interface lo0" in explanation
+
+ def test_fix_runs_route_command(self) -> None:
+ _is_root_user.cache_clear()
+ configurator = MulticastConfiguratorMacOS()
+ with patch("os.geteuid", return_value=0):
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0)
+ configurator.fix()
+ mock_run.assert_called_once()
+ args = mock_run.call_args[0][0]
+ assert "route" in args
+ assert "224.0.0.0/4" in args
+
+
+# ----------------------------- BufferConfiguratorLinux tests -----------------------------
+
+
+class TestBufferConfiguratorLinux:
+ def test_check_returns_true_when_buffers_sufficient(self) -> None:
+ configurator = BufferConfiguratorLinux()
+ with patch("dimos.protocol.service.system_configurator._read_sysctl_int") as mock_read:
+ mock_read.return_value = IDEAL_RMEM_SIZE
+ assert configurator.check() is True
+ assert configurator.needs == []
+
+ def test_check_returns_false_when_rmem_max_low(self) -> None:
+ configurator = BufferConfiguratorLinux()
+ with patch("dimos.protocol.service.system_configurator._read_sysctl_int") as mock_read:
+ mock_read.side_effect = [1048576, IDEAL_RMEM_SIZE] # rmem_max low
+ assert configurator.check() is False
+ assert len(configurator.needs) == 1
+ assert configurator.needs[0][0] == "net.core.rmem_max"
+
+ def test_check_returns_false_when_both_low(self) -> None:
+ configurator = BufferConfiguratorLinux()
+ with patch("dimos.protocol.service.system_configurator._read_sysctl_int") as mock_read:
+ mock_read.return_value = 1048576 # Both low
+ assert configurator.check() is False
+ assert len(configurator.needs) == 2
+
+ def test_explanation_lists_needed_changes(self) -> None:
+ configurator = BufferConfiguratorLinux()
+ configurator.needs = [("net.core.rmem_max", IDEAL_RMEM_SIZE)]
+ explanation = configurator.explanation()
+ assert "net.core.rmem_max" in explanation
+ assert str(IDEAL_RMEM_SIZE) in explanation
+
+ def test_fix_writes_needed_values(self) -> None:
+ configurator = BufferConfiguratorLinux()
+ configurator.needs = [("net.core.rmem_max", IDEAL_RMEM_SIZE)]
+ with patch("dimos.protocol.service.system_configurator._write_sysctl_int") as mock_write:
+ configurator.fix()
+ mock_write.assert_called_once_with("net.core.rmem_max", IDEAL_RMEM_SIZE)
+
+
+# ----------------------------- BufferConfiguratorMacOS tests -----------------------------
+
+
+class TestBufferConfiguratorMacOS:
+ def test_check_returns_true_when_buffers_sufficient(self) -> None:
+ configurator = BufferConfiguratorMacOS()
+ with patch("dimos.protocol.service.system_configurator._read_sysctl_int") as mock_read:
+ mock_read.side_effect = [
+ BufferConfiguratorMacOS.TARGET_BUFFER_SIZE,
+ BufferConfiguratorMacOS.TARGET_RECVSPACE,
+ BufferConfiguratorMacOS.TARGET_DGRAM_SIZE,
+ ]
+ assert configurator.check() is True
+ assert configurator.needs == []
+
+ def test_check_returns_false_when_values_low(self) -> None:
+ configurator = BufferConfiguratorMacOS()
+ with patch("dimos.protocol.service.system_configurator._read_sysctl_int") as mock_read:
+ mock_read.return_value = 1024 # All low
+ assert configurator.check() is False
+ assert len(configurator.needs) == 3
+
+ def test_explanation_lists_needed_changes(self) -> None:
+ configurator = BufferConfiguratorMacOS()
+ configurator.needs = [
+ ("kern.ipc.maxsockbuf", BufferConfiguratorMacOS.TARGET_BUFFER_SIZE),
+ ]
+ explanation = configurator.explanation()
+ assert "kern.ipc.maxsockbuf" in explanation
+
+ def test_fix_writes_needed_values(self) -> None:
+ configurator = BufferConfiguratorMacOS()
+ configurator.needs = [
+ ("kern.ipc.maxsockbuf", BufferConfiguratorMacOS.TARGET_BUFFER_SIZE),
+ ]
+ with patch("dimos.protocol.service.system_configurator._write_sysctl_int") as mock_write:
+ configurator.fix()
+ mock_write.assert_called_once_with(
+ "kern.ipc.maxsockbuf", BufferConfiguratorMacOS.TARGET_BUFFER_SIZE
+ )
+
+
+# ----------------------------- MaxFileConfiguratorMacOS tests -----------------------------
+
+
+class TestMaxFileConfiguratorMacOS:
+ def test_check_returns_true_when_soft_limit_sufficient(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ with patch("resource.getrlimit") as mock_getrlimit:
+ mock_getrlimit.return_value = (65536, 1048576)
+ assert configurator.check() is True
+ assert configurator.current_soft == 65536
+ assert configurator.current_hard == 1048576
+
+ def test_check_returns_false_when_soft_limit_low(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ with patch("resource.getrlimit") as mock_getrlimit:
+ mock_getrlimit.return_value = (256, 1048576)
+ assert configurator.check() is False
+ assert configurator.can_fix_without_sudo is True
+
+ def test_check_returns_false_when_both_limits_low(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ with patch("resource.getrlimit") as mock_getrlimit:
+ mock_getrlimit.return_value = (256, 10240)
+ assert configurator.check() is False
+ assert configurator.can_fix_without_sudo is False
+
+ def test_check_returns_false_on_exception(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ with patch("resource.getrlimit", side_effect=Exception("error")):
+ assert configurator.check() is False
+
+ def test_explanation_when_sudo_not_needed(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ configurator.current_soft = 256
+ configurator.current_hard = 1048576
+ configurator.can_fix_without_sudo = True
+ explanation = configurator.explanation()
+ assert "65536" in explanation
+ assert "no sudo" in explanation.lower() or "Raise soft" in explanation
+
+ def test_explanation_when_sudo_needed(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ configurator.current_soft = 256
+ configurator.current_hard = 10240
+ configurator.can_fix_without_sudo = False
+ explanation = configurator.explanation()
+ assert "launchctl" in explanation
+
+ def test_fix_raises_soft_limit_without_sudo(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ configurator.current_soft = 256
+ configurator.current_hard = 1048576
+ configurator.can_fix_without_sudo = True
+ with patch("resource.setrlimit") as mock_setrlimit:
+ configurator.fix()
+ mock_setrlimit.assert_called_once_with(resource.RLIMIT_NOFILE, (65536, 1048576))
+
+ def test_fix_does_nothing_when_already_sufficient(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ configurator.current_soft = 65536
+ configurator.current_hard = 1048576
+ with patch("resource.setrlimit") as mock_setrlimit:
+ configurator.fix()
+ mock_setrlimit.assert_not_called()
+
+ def test_fix_uses_launchctl_when_hard_limit_low(self) -> None:
+ _is_root_user.cache_clear()
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ configurator.current_soft = 256
+ configurator.current_hard = 10240
+ configurator.can_fix_without_sudo = False
+ with patch("os.geteuid", return_value=0):
+ with patch("subprocess.run") as mock_run:
+ mock_run.return_value = MagicMock(returncode=0)
+ with patch("resource.setrlimit"):
+ configurator.fix()
+ # Check launchctl was called
+ args = mock_run.call_args[0][0]
+ assert "launchctl" in args
+ assert "maxfiles" in args
+
+ def test_fix_raises_on_setrlimit_error(self) -> None:
+ configurator = MaxFileConfiguratorMacOS(target=65536)
+ configurator.current_soft = 256
+ configurator.current_hard = 1048576
+ configurator.can_fix_without_sudo = True
+ with patch("resource.setrlimit", side_effect=ValueError("test error")):
+ with pytest.raises(ValueError):
+ configurator.fix()
diff --git a/dimos/protocol/skill/test_coordinator.py b/dimos/protocol/skill/test_coordinator.py
index acaad98dda..bd00ea69c2 100644
--- a/dimos/protocol/skill/test_coordinator.py
+++ b/dimos/protocol/skill/test_coordinator.py
@@ -96,6 +96,7 @@ def take_photo(self) -> Image:
return img
+@pytest.mark.integration
@pytest.mark.asyncio # type: ignore[untyped-decorator]
async def test_coordinator_parallel_calls() -> None:
container = SkillContainerTest()
@@ -136,6 +137,7 @@ async def test_coordinator_parallel_calls() -> None:
skillCoordinator.stop()
+@pytest.mark.integration
@pytest.mark.asyncio # type: ignore[untyped-decorator]
async def test_coordinator_generator() -> None:
container = SkillContainerTest()
diff --git a/dimos/robot/agilex/README.md b/dimos/robot/agilex/README.md
deleted file mode 100644
index 8342a6045e..0000000000
--- a/dimos/robot/agilex/README.md
+++ /dev/null
@@ -1,371 +0,0 @@
-# DIMOS Manipulator Robot Development Guide
-
-This guide explains how to create robot classes, integrate agents, and use the DIMOS module system with LCM transport.
-
-## Table of Contents
-1. [Robot Class Architecture](#robot-class-architecture)
-2. [Module System & LCM Transport](#module-system--lcm-transport)
-3. [Agent Integration](#agent-integration)
-4. [Complete Example](#complete-example)
-
-## Robot Class Architecture
-
-### Basic Robot Class Structure
-
-A DIMOS robot class should follow this pattern:
-
-```python
-from typing import Optional, List
-from dimos import core
-from dimos.types.robot_capabilities import RobotCapability
-
-class YourRobot:
- """Your robot implementation."""
-
- def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None):
- # Core components
- self.dimos = None
- self.modules = {}
- self.skill_library = SkillLibrary()
-
- # Define capabilities
- self.capabilities = robot_capabilities or [
- RobotCapability.VISION,
- RobotCapability.MANIPULATION,
- ]
-
- async def start(self):
- """Start the robot modules."""
- # Initialize DIMOS with worker count
- self.dimos = core.start(2) # Number of workers needed
-
- # Deploy modules
- # ... (see Module System section)
-
- def stop(self):
- """Stop all modules and clean up."""
- # Stop modules
- # Close DIMOS
- if self.dimos:
- self.dimos.close()
-```
-
-### Key Components Explained
-
-1. **Initialization**: Store references to modules, skills, and capabilities
-2. **Async Start**: Modules must be deployed asynchronously
-3. **Proper Cleanup**: Always stop modules before closing DIMOS
-
-## Module System & LCM Transport
-
-### Understanding DIMOS Modules
-
-Modules are the building blocks of DIMOS robots. They:
-- Process data streams (inputs)
-- Produce outputs
-- Can be connected together
-- Communicate via LCM (Lightweight Communications and Marshalling)
-
-### Deploying a Module
-
-```python
-# Deploy a camera module
-self.camera = self.dimos.deploy(
- ZEDModule, # Module class
- camera_id=0, # Module parameters
- resolution="HD720",
- depth_mode="NEURAL",
- fps=30,
- publish_rate=30.0,
- frame_id="camera_frame"
-)
-```
-
-### Setting Up LCM Transport
-
-LCM transport enables inter-module communication:
-
-```python
-# Enable LCM auto-configuration
-from dimos.protocol import pubsub
-pubsub.lcm.autoconf()
-
-# Configure output transport
-self.camera.color_image.transport = core.LCMTransport(
- "/camera/color_image", # Topic name
- Image # Message type
-)
-self.camera.depth_image.transport = core.LCMTransport(
- "/camera/depth_image",
- Image
-)
-```
-
-### Connecting Modules
-
-Connect module outputs to inputs:
-
-```python
-# Connect manipulation module to camera outputs
-self.manipulation.rgb_image.connect(self.camera.color_image)
-self.manipulation.depth_image.connect(self.camera.depth_image)
-self.manipulation.camera_info.connect(self.camera.camera_info)
-```
-
-### Module Communication Pattern
-
-```
-┌──────────────┐ LCM ┌────────────────┐ LCM ┌──────────────┐
-│ Camera │────────▶│ Manipulation │────────▶│ Visualization│
-│ Module │ Messages│ Module │ Messages│ Output │
-└──────────────┘ └────────────────┘ └──────────────┘
- ▲ ▲
- │ │
- └──────────────────────────┘
- Direct Connection via RPC call
-```
-
-## Agent Integration
-
-### Setting Up Agent with Robot
-
-The run file pattern for agent integration:
-
-```python
-#!/usr/bin/env python3
-import asyncio
-import reactivex as rx
-from dimos.agents_deprecated.claude_agent import ClaudeAgent
-from dimos.web.robot_web_interface import RobotWebInterface
-
-def main():
- # 1. Create and start robot
- robot = YourRobot()
- asyncio.run(robot.start())
-
- # 2. Set up skills
- skills = robot.get_skills()
- skills.add(YourSkill)
- skills.create_instance("YourSkill", robot=robot)
-
- # 3. Set up reactive streams
- agent_response_subject = rx.subject.Subject()
- agent_response_stream = agent_response_subject.pipe(ops.share())
-
- # 4. Create web interface
- web_interface = RobotWebInterface(
- port=5555,
- text_streams={"agent_responses": agent_response_stream},
- audio_subject=rx.subject.Subject()
- )
-
- # 5. Create agent
- agent = ClaudeAgent(
- dev_name="your_agent",
- input_query_stream=web_interface.query_stream,
- skills=skills,
- system_query="Your system prompt here",
- model_name="claude-3-5-haiku-latest"
- )
-
- # 6. Connect agent responses
- agent.get_response_observable().subscribe(
- lambda x: agent_response_subject.on_next(x)
- )
-
- # 7. Run interface
- web_interface.run()
-```
-
-### Key Integration Points
-
-1. **Reactive Streams**: Use RxPy for event-driven communication
-2. **Web Interface**: Provides user input/output
-3. **Agent**: Processes natural language and executes skills
-4. **Skills**: Define robot capabilities as executable actions
-
-## Complete Example
-
-### Step 1: Create Robot Class (`my_robot.py`)
-
-```python
-import asyncio
-from typing import Optional, List
-from dimos import core
-from dimos.hardware.camera import CameraModule
-from dimos.manipulation.module import ManipulationModule
-from dimos.skills.skills import SkillLibrary
-from dimos.types.robot_capabilities import RobotCapability
-from dimos_lcm.sensor_msgs import Image, CameraInfo
-from dimos.protocol import pubsub
-
-class MyRobot:
- def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None):
- self.dimos = None
- self.camera = None
- self.manipulation = None
- self.skill_library = SkillLibrary()
-
- self.capabilities = robot_capabilities or [
- RobotCapability.VISION,
- RobotCapability.MANIPULATION,
- ]
-
- async def start(self):
- # Start DIMOS
- self.dimos = core.start(2)
-
- # Enable LCM
- pubsub.lcm.autoconf()
-
- # Deploy camera
- self.camera = self.dimos.deploy(
- CameraModule,
- camera_id=0,
- fps=30
- )
-
- # Configure camera LCM
- self.camera.color_image.transport = core.LCMTransport("/camera/rgb", Image)
- self.camera.depth_image.transport = core.LCMTransport("/camera/depth", Image)
- self.camera.camera_info.transport = core.LCMTransport("/camera/info", CameraInfo)
-
- # Deploy manipulation
- self.manipulation = self.dimos.deploy(ManipulationModule)
-
- # Connect modules
- self.manipulation.rgb_image.connect(self.camera.color_image)
- self.manipulation.depth_image.connect(self.camera.depth_image)
- self.manipulation.camera_info.connect(self.camera.camera_info)
-
- # Configure manipulation output
- self.manipulation.viz_image.transport = core.LCMTransport("/viz/output", Image)
-
- # Start modules
- self.camera.start()
- self.manipulation.start()
-
- await asyncio.sleep(2) # Allow initialization
-
- def get_skills(self):
- return self.skill_library
-
- def stop(self):
- if self.manipulation:
- self.manipulation.stop()
- if self.camera:
- self.camera.stop()
- if self.dimos:
- self.dimos.close()
-```
-
-### Step 2: Create Run Script (`run.py`)
-
-```python
-#!/usr/bin/env python3
-import asyncio
-import os
-from my_robot import MyRobot
-from dimos.agents_deprecated.claude_agent import ClaudeAgent
-from dimos.skills.basic import BasicSkill
-from dimos.web.robot_web_interface import RobotWebInterface
-import reactivex as rx
-import reactivex.operators as ops
-
-SYSTEM_PROMPT = """You are a helpful robot assistant."""
-
-def main():
- # Check API key
- if not os.getenv("ANTHROPIC_API_KEY"):
- print("Please set ANTHROPIC_API_KEY")
- return
-
- # Create robot
- robot = MyRobot()
-
- try:
- # Start robot
- asyncio.run(robot.start())
-
- # Set up skills
- skills = robot.get_skills()
- skills.add(BasicSkill)
- skills.create_instance("BasicSkill", robot=robot)
-
- # Set up streams
- agent_response_subject = rx.subject.Subject()
- agent_response_stream = agent_response_subject.pipe(ops.share())
-
- # Create web interface
- web_interface = RobotWebInterface(
- port=5555,
- text_streams={"agent_responses": agent_response_stream}
- )
-
- # Create agent
- agent = ClaudeAgent(
- dev_name="my_agent",
- input_query_stream=web_interface.query_stream,
- skills=skills,
- system_query=SYSTEM_PROMPT
- )
-
- # Connect responses
- agent.get_response_observable().subscribe(
- lambda x: agent_response_subject.on_next(x)
- )
-
- print("Robot ready at http://localhost:5555")
-
- # Run
- web_interface.run()
-
- finally:
- robot.stop()
-
-if __name__ == "__main__":
- main()
-```
-
-### Step 3: Define Skills (`skills.py`)
-
-```python
-from dimos.skills import Skill, skill
-
-@skill(
- description="Perform a basic action",
- parameters={
- "action": "The action to perform"
- }
-)
-class BasicSkill(Skill):
- def __init__(self, robot):
- self.robot = robot
-
- def run(self, action: str):
- # Implement skill logic
- return f"Performed: {action}"
-```
-
-## Best Practices
-
-1. **Module Lifecycle**: Always start DIMOS before deploying modules
-2. **LCM Topics**: Use descriptive topic names with namespaces
-3. **Error Handling**: Wrap module operations in try-except blocks
-4. **Resource Cleanup**: Ensure proper cleanup in stop() methods
-5. **Async Operations**: Use asyncio for non-blocking operations
-6. **Stream Management**: Use RxPy for reactive programming patterns
-
-## Debugging Tips
-
-1. **Check Module Status**: Print module.io().result() to see connections
-2. **Monitor LCM**: Use Foxglove to visualize LCM messages
-3. **Log Everything**: Use dimos.utils.logging_config.setup_logger()
-4. **Test Modules Independently**: Deploy and test one module at a time
-
-## Common Issues
-
-1. **"Module not started"**: Ensure start() is called after deployment
-2. **"No data received"**: Check LCM transport configuration
-3. **"Connection failed"**: Verify input/output types match
-4. **"Cleanup errors"**: Stop modules before closing DIMOS
diff --git a/dimos/robot/agilex/README_CN.md b/dimos/robot/agilex/README_CN.md
deleted file mode 100644
index a8d79ebec1..0000000000
--- a/dimos/robot/agilex/README_CN.md
+++ /dev/null
@@ -1,465 +0,0 @@
-# DIMOS 机械臂机器人开发指南
-
-本指南介绍如何创建机器人类、集成智能体(Agent)以及使用 DIMOS 模块系统和 LCM 传输。
-
-## 目录
-1. [机器人类架构](#机器人类架构)
-2. [模块系统与 LCM 传输](#模块系统与-lcm-传输)
-3. [智能体集成](#智能体集成)
-4. [完整示例](#完整示例)
-
-## 机器人类架构
-
-### 基本机器人类结构
-
-DIMOS 机器人类应遵循以下模式:
-
-```python
-from typing import Optional, List
-from dimos import core
-from dimos.types.robot_capabilities import RobotCapability
-
-class YourRobot:
- """您的机器人实现。"""
-
- def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None):
- # 核心组件
- self.dimos = None
- self.modules = {}
- self.skill_library = SkillLibrary()
-
- # 定义能力
- self.capabilities = robot_capabilities or [
- RobotCapability.VISION,
- RobotCapability.MANIPULATION,
- ]
-
- async def start(self):
- """启动机器人模块。"""
- # 初始化 DIMOS,指定工作线程数
- self.dimos = core.start(2) # 需要的工作线程数
-
- # 部署模块
- # ... (参见模块系统章节)
-
- def stop(self):
- """停止所有模块并清理资源。"""
- # 停止模块
- # 关闭 DIMOS
- if self.dimos:
- self.dimos.close()
-```
-
-### 关键组件说明
-
-1. **初始化**:存储模块、技能和能力的引用
-2. **异步启动**:模块必须异步部署
-3. **正确清理**:在关闭 DIMOS 之前始终停止模块
-
-## 模块系统与 LCM 传输
-
-### 理解 DIMOS 模块
-
-模块是 DIMOS 机器人的构建块。它们:
-- 处理数据流(输入)
-- 产生输出
-- 可以相互连接
-- 通过 LCM(轻量级通信和编组)进行通信
-
-### 部署模块
-
-```python
-# 部署相机模块
-self.camera = self.dimos.deploy(
- ZEDModule, # 模块类
- camera_id=0, # 模块参数
- resolution="HD720",
- depth_mode="NEURAL",
- fps=30,
- publish_rate=30.0,
- frame_id="camera_frame"
-)
-```
-
-### 设置 LCM 传输
-
-LCM 传输实现模块间通信:
-
-```python
-# 启用 LCM 自动配置
-from dimos.protocol import pubsub
-pubsub.lcm.autoconf()
-
-# 配置输出传输
-self.camera.color_image.transport = core.LCMTransport(
- "/camera/color_image", # 主题名称
- Image # 消息类型
-)
-self.camera.depth_image.transport = core.LCMTransport(
- "/camera/depth_image",
- Image
-)
-```
-
-### 连接模块
-
-将模块输出连接到输入:
-
-```python
-# 将操作模块连接到相机输出
-self.manipulation.rgb_image.connect(self.camera.color_image) # ROS set_callback
-self.manipulation.depth_image.connect(self.camera.depth_image)
-self.manipulation.camera_info.connect(self.camera.camera_info)
-```
-
-### 模块通信模式
-
-```
-┌──────────────┐ LCM ┌────────────────┐ LCM ┌──────────────┐
-│ 相机模块 │────────▶│ 操作模块 │────────▶│ 可视化输出 │
-│ │ 消息 │ │ 消息 │ │
-└──────────────┘ └────────────────┘ └──────────────┘
- ▲ ▲
- │ │
- └──────────────────────────┘
- 直接连接(RPC指令)
-```
-
-## 智能体集成
-
-### 设置智能体与机器人
-
-运行文件的智能体集成模式:
-
-```python
-#!/usr/bin/env python3
-import asyncio
-import reactivex as rx
-from dimos.agents_deprecated.claude_agent import ClaudeAgent
-from dimos.web.robot_web_interface import RobotWebInterface
-
-def main():
- # 1. 创建并启动机器人
- robot = YourRobot()
- asyncio.run(robot.start())
-
- # 2. 设置技能
- skills = robot.get_skills()
- skills.add(YourSkill)
- skills.create_instance("YourSkill", robot=robot)
-
- # 3. 设置响应式流
- agent_response_subject = rx.subject.Subject()
- agent_response_stream = agent_response_subject.pipe(ops.share())
-
- # 4. 创建 Web 界面
- web_interface = RobotWebInterface(
- port=5555,
- text_streams={"agent_responses": agent_response_stream},
- audio_subject=rx.subject.Subject()
- )
-
- # 5. 创建智能体
- agent = ClaudeAgent(
- dev_name="your_agent",
- input_query_stream=web_interface.query_stream,
- skills=skills,
- system_query="您的系统提示词",
- model_name="claude-3-5-haiku-latest"
- )
-
- # 6. 连接智能体响应
- agent.get_response_observable().subscribe(
- lambda x: agent_response_subject.on_next(x)
- )
-
- # 7. 运行界面
- web_interface.run()
-```
-
-### 关键集成点
-
-1. **响应式流**:使用 RxPy 进行事件驱动通信
-2. **Web 界面**:提供用户输入/输出
-3. **智能体**:处理自然语言并执行技能
-4. **技能**:将机器人能力定义为可执行动作
-
-## 完整示例
-
-### 步骤 1:创建机器人类(`my_robot.py`)
-
-```python
-import asyncio
-from typing import Optional, List
-from dimos import core
-from dimos.hardware.camera import CameraModule
-from dimos.manipulation.module import ManipulationModule
-from dimos.skills.skills import SkillLibrary
-from dimos.types.robot_capabilities import RobotCapability
-from dimos_lcm.sensor_msgs import Image, CameraInfo
-from dimos.protocol import pubsub
-
-class MyRobot:
- def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None):
- self.dimos = None
- self.camera = None
- self.manipulation = None
- self.skill_library = SkillLibrary()
-
- self.capabilities = robot_capabilities or [
- RobotCapability.VISION,
- RobotCapability.MANIPULATION,
- ]
-
- async def start(self):
- # 启动 DIMOS
- self.dimos = core.start(2)
-
- # 启用 LCM
- pubsub.lcm.autoconf()
-
- # 部署相机
- self.camera = self.dimos.deploy(
- CameraModule,
- camera_id=0,
- fps=30
- )
-
- # 配置相机 LCM
- self.camera.color_image.transport = core.LCMTransport("/camera/rgb", Image)
- self.camera.depth_image.transport = core.LCMTransport("/camera/depth", Image)
- self.camera.camera_info.transport = core.LCMTransport("/camera/info", CameraInfo)
-
- # 部署操作模块
- self.manipulation = self.dimos.deploy(ManipulationModule)
-
- # 连接模块
- self.manipulation.rgb_image.connect(self.camera.color_image)
- self.manipulation.depth_image.connect(self.camera.depth_image)
- self.manipulation.camera_info.connect(self.camera.camera_info)
-
- # 配置操作输出
- self.manipulation.viz_image.transport = core.LCMTransport("/viz/output", Image)
-
- # 启动模块
- self.camera.start()
- self.manipulation.start()
-
- await asyncio.sleep(2) # 允许初始化
-
- def get_skills(self):
- return self.skill_library
-
- def stop(self):
- if self.manipulation:
- self.manipulation.stop()
- if self.camera:
- self.camera.stop()
- if self.dimos:
- self.dimos.close()
-```
-
-### 步骤 2:创建运行脚本(`run.py`)
-
-```python
-#!/usr/bin/env python3
-import asyncio
-import os
-from my_robot import MyRobot
-from dimos.agents_deprecated.claude_agent import ClaudeAgent
-from dimos.skills.basic import BasicSkill
-from dimos.web.robot_web_interface import RobotWebInterface
-import reactivex as rx
-import reactivex.operators as ops
-
-SYSTEM_PROMPT = """您是一个有用的机器人助手。"""
-
-def main():
- # 检查 API 密钥
- if not os.getenv("ANTHROPIC_API_KEY"):
- print("请设置 ANTHROPIC_API_KEY")
- return
-
- # 创建机器人
- robot = MyRobot()
-
- try:
- # 启动机器人
- asyncio.run(robot.start())
-
- # 设置技能
- skills = robot.get_skills()
- skills.add(BasicSkill)
- skills.create_instance("BasicSkill", robot=robot)
-
- # 设置流
- agent_response_subject = rx.subject.Subject()
- agent_response_stream = agent_response_subject.pipe(ops.share())
-
- # 创建 Web 界面
- web_interface = RobotWebInterface(
- port=5555,
- text_streams={"agent_responses": agent_response_stream}
- )
-
- # 创建智能体
- agent = ClaudeAgent(
- dev_name="my_agent",
- input_query_stream=web_interface.query_stream,
- skills=skills,
- system_query=SYSTEM_PROMPT
- )
-
- # 连接响应
- agent.get_response_observable().subscribe(
- lambda x: agent_response_subject.on_next(x)
- )
-
- print("机器人就绪,访问 http://localhost:5555")
-
- # 运行
- web_interface.run()
-
- finally:
- robot.stop()
-
-if __name__ == "__main__":
- main()
-```
-
-### 步骤 3:定义技能(`skills.py`)
-
-```python
-from dimos.skills import Skill, skill
-
-@skill(
- description="执行一个基本动作",
- parameters={
- "action": "要执行的动作"
- }
-)
-class BasicSkill(Skill):
- def __init__(self, robot):
- self.robot = robot
-
- def run(self, action: str):
- # 实现技能逻辑
- return f"已执行:{action}"
-```
-
-## 最佳实践
-
-1. **模块生命周期**:在部署模块之前始终先启动 DIMOS
-2. **LCM 主题**:使用带命名空间的描述性主题名称
-3. **错误处理**:用 try-except 块包装模块操作
-4. **资源清理**:确保在 stop() 方法中正确清理
-5. **异步操作**:使用 asyncio 进行非阻塞操作
-6. **流管理**:使用 RxPy 进行响应式编程模式
-
-## 调试技巧
-
-1. **检查模块状态**:打印 module.io().result() 查看连接
-2. **监控 LCM**:使用 Foxglove 可视化 LCM 消息
-3. **记录一切**:使用 dimos.utils.logging_config.setup_logger()
-4. **独立测试模块**:一次部署和测试一个模块
-
-## 常见问题
-
-1. **"模块未启动"**:确保在部署后调用 start()
-2. **"未收到数据"**:检查 LCM 传输配置
-3. **"连接失败"**:验证输入/输出类型是否匹配
-4. **"清理错误"**:在关闭 DIMOS 之前停止模块
-
-## 高级主题
-
-### 自定义模块开发
-
-创建自定义模块的基本结构:
-
-```python
-from dimos.core import Module, In, Out, rpc
-
-class CustomModule(Module):
- # 定义输入
- input_data: In[DataType]
-
- # 定义输出
- output_data: Out[DataType]
-
- def __init__(self, param1, param2, **kwargs):
- super().__init__(**kwargs)
- self.param1 = param1
- self.param2 = param2
-
- @rpc
- def start(self):
- """启动模块处理。"""
- self.input_data.subscribe(self._process_data)
-
- def _process_data(self, data):
- """处理输入数据。"""
- # 处理逻辑
- result = self.process(data)
- # 发布输出
- self.output_data.publish(result)
-
- @rpc
- def stop(self):
- """停止模块。"""
- # 清理资源
- pass
-```
-
-### 技能开发指南
-
-技能是机器人可执行的高级动作:
-
-```python
-from dimos.skills import Skill, skill
-from typing import Optional
-
-@skill(
- description="复杂操作技能",
- parameters={
- "target": "目标对象",
- "location": "目标位置"
- }
-)
-class ComplexSkill(Skill):
- def __init__(self, robot, **kwargs):
- super().__init__(**kwargs)
- self.robot = robot
-
- def run(self, target: str, location: Optional[str] = None):
- """执行技能逻辑。"""
- try:
- # 1. 感知阶段
- object_info = self.robot.detect_object(target)
-
- # 2. 规划阶段
- if location:
- plan = self.robot.plan_movement(object_info, location)
-
- # 3. 执行阶段
- result = self.robot.execute_plan(plan)
-
- return {
- "success": True,
- "message": f"成功移动 {target} 到 {location}"
- }
-
- except Exception as e:
- return {
- "success": False,
- "error": str(e)
- }
-```
-
-### 性能优化
-
-1. **并行处理**:使用多个工作线程处理不同模块
-2. **数据缓冲**:为高频数据流实现缓冲机制
-3. **延迟加载**:仅在需要时初始化重型模块
-4. **资源池化**:重用昂贵的资源(如神经网络模型)
-
-希望本指南能帮助您快速上手 DIMOS 机器人开发!
diff --git a/dimos/robot/agilex/piper_arm.py b/dimos/robot/agilex/piper_arm.py
deleted file mode 100644
index 29624b9a4c..0000000000
--- a/dimos/robot/agilex/piper_arm.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-
-# Import LCM message types
-from dimos_lcm.sensor_msgs import CameraInfo
-
-from dimos import core
-from dimos.hardware.sensors.camera.zed import ZEDModule
-from dimos.manipulation.visual_servoing.manipulation_module import ManipulationModule
-from dimos.msgs.sensor_msgs import Image
-from dimos.protocol import pubsub
-from dimos.robot.foxglove_bridge import FoxgloveBridge
-from dimos.robot.robot import Robot
-from dimos.skills.skills import SkillLibrary
-from dimos.types.robot_capabilities import RobotCapability
-from dimos.utils.logging_config import setup_logger
-
-logger = setup_logger()
-
-
-class PiperArmRobot(Robot):
- """Piper Arm robot with ZED camera and manipulation capabilities."""
-
- def __init__(self, robot_capabilities: list[RobotCapability] | None = None) -> None:
- super().__init__()
- self.dimos = None
- self.stereo_camera = None
- self.manipulation_interface = None
- self.skill_library = SkillLibrary() # type: ignore[assignment]
-
- # Initialize capabilities
- self.capabilities = robot_capabilities or [
- RobotCapability.VISION,
- RobotCapability.MANIPULATION,
- ]
-
- async def start(self) -> None:
- """Start the robot modules."""
- # Start Dimos
- self.dimos = core.start(2) # type: ignore[assignment] # Need 2 workers for ZED and manipulation modules
- self.foxglove_bridge = FoxgloveBridge()
-
- # Enable LCM auto-configuration
- pubsub.lcm.autoconf() # type: ignore[attr-defined]
-
- # Deploy ZED module
- logger.info("Deploying ZED module...")
- self.stereo_camera = self.dimos.deploy( # type: ignore[attr-defined]
- ZEDModule,
- camera_id=0,
- resolution="HD720",
- depth_mode="NEURAL",
- fps=30,
- enable_tracking=False, # We don't need tracking for manipulation
- publish_rate=30.0,
- frame_id="zed_camera",
- )
-
- # Configure ZED LCM transports
- self.stereo_camera.color_image.transport = core.LCMTransport("/zed/color_image", Image) # type: ignore[attr-defined]
- self.stereo_camera.depth_image.transport = core.LCMTransport("/zed/depth_image", Image) # type: ignore[attr-defined]
- self.stereo_camera.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) # type: ignore[attr-defined]
-
- # Deploy manipulation module
- logger.info("Deploying manipulation module...")
- self.manipulation_interface = self.dimos.deploy(ManipulationModule) # type: ignore[attr-defined]
-
- # Connect manipulation inputs to ZED outputs
- self.manipulation_interface.rgb_image.connect(self.stereo_camera.color_image) # type: ignore[attr-defined]
- self.manipulation_interface.depth_image.connect(self.stereo_camera.depth_image) # type: ignore[attr-defined]
- self.manipulation_interface.camera_info.connect(self.stereo_camera.camera_info) # type: ignore[attr-defined]
-
- # Configure manipulation output
- self.manipulation_interface.viz_image.transport = core.LCMTransport( # type: ignore[attr-defined]
- "/manipulation/viz", Image
- )
-
- # Print module info
- logger.info("Modules configured:")
- print("\nZED Module:")
- print(self.stereo_camera.io()) # type: ignore[attr-defined]
- print("\nManipulation Module:")
- print(self.manipulation_interface.io()) # type: ignore[attr-defined]
-
- # Start modules
- logger.info("Starting modules...")
- self.foxglove_bridge.start()
- self.stereo_camera.start() # type: ignore[attr-defined]
- self.manipulation_interface.start() # type: ignore[attr-defined]
-
- # Give modules time to initialize
- await asyncio.sleep(2)
-
- logger.info("PiperArmRobot initialized and started")
-
- def pick_and_place( # type: ignore[no-untyped-def]
- self, pick_x: int, pick_y: int, place_x: int | None = None, place_y: int | None = None
- ):
- """Execute pick and place task.
-
- Args:
- pick_x: X coordinate for pick location
- pick_y: Y coordinate for pick location
- place_x: X coordinate for place location (optional)
- place_y: Y coordinate for place location (optional)
-
- Returns:
- Result of the pick and place operation
- """
- if self.manipulation_interface:
- return self.manipulation_interface.pick_and_place(pick_x, pick_y, place_x, place_y)
- else:
- logger.error("Manipulation module not initialized")
- return False
-
- def handle_keyboard_command(self, key: str): # type: ignore[no-untyped-def]
- """Pass keyboard commands to manipulation module.
-
- Args:
- key: Keyboard key pressed
-
- Returns:
- Action taken or None
- """
- if self.manipulation_interface:
- return self.manipulation_interface.handle_keyboard_command(key)
- else:
- logger.error("Manipulation module not initialized")
- return None
-
- def stop(self) -> None:
- """Stop all modules and clean up."""
- logger.info("Stopping PiperArmRobot...")
-
- try:
- if self.manipulation_interface:
- self.manipulation_interface.stop()
-
- if self.stereo_camera:
- self.stereo_camera.stop()
- except Exception as e:
- logger.warning(f"Error stopping modules: {e}")
-
- # Close dimos last to ensure workers are available for cleanup
- if self.dimos:
- self.dimos.close()
-
- logger.info("PiperArmRobot stopped")
-
-
-async def run_piper_arm() -> None:
- """Run the Piper Arm robot."""
- robot = PiperArmRobot() # type: ignore[abstract]
-
- await robot.start()
-
- # Keep the robot running
- try:
- while True:
- await asyncio.sleep(1)
- except KeyboardInterrupt:
- logger.info("Keyboard interrupt received")
- finally:
- await robot.stop() # type: ignore[func-returns-value]
-
-
-if __name__ == "__main__":
- asyncio.run(run_piper_arm())
diff --git a/dimos/robot/agilex/run.py b/dimos/robot/agilex/run.py
deleted file mode 100644
index 64e0ae5470..0000000000
--- a/dimos/robot/agilex/run.py
+++ /dev/null
@@ -1,190 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Run script for Piper Arm robot with Claude agent integration.
-Provides manipulation capabilities with natural language interface.
-"""
-
-import asyncio
-import os
-import sys
-
-from dotenv import load_dotenv
-import reactivex as rx
-import reactivex.operators as ops
-
-from dimos.agents_deprecated.claude_agent import ClaudeAgent
-from dimos.robot.agilex.piper_arm import PiperArmRobot
-from dimos.skills.kill_skill import KillSkill
-from dimos.skills.manipulation.pick_and_place import PickAndPlace
-from dimos.stream.audio.pipelines import stt, tts
-from dimos.utils.logging_config import setup_logger
-from dimos.web.robot_web_interface import RobotWebInterface
-
-logger = setup_logger()
-
-# Load environment variables
-load_dotenv()
-
-# System prompt for the Piper Arm manipulation agent
-SYSTEM_PROMPT = """You are an intelligent robotic assistant controlling a Piper Arm robot with advanced manipulation capabilities. Your primary role is to help users with pick and place tasks using natural language understanding.
-
-## Your Capabilities:
-1. **Visual Perception**: You have access to a ZED stereo camera that provides RGB and depth information
-2. **Object Manipulation**: You can pick up and place objects using a 6-DOF robotic arm with a gripper
-3. **Language Understanding**: You use the Qwen vision-language model to identify objects and locations from natural language descriptions
-
-## Available Skills:
-- **PickAndPlace**: Execute pick and place operations based on object and location descriptions
- - Pick only: "Pick up the red mug"
- - Pick and place: "Move the book to the shelf"
-- **KillSkill**: Stop any currently running skill
-
-## Guidelines:
-1. **Safety First**: Always ensure safe operation. If unsure about an object's graspability or a placement location's stability, ask for clarification
-2. **Clear Communication**: Explain what you're doing and ask for confirmation when needed
-3. **Error Handling**: If a task fails, explain why and suggest alternatives
-4. **Precision**: When users give specific object descriptions, use them exactly as provided to the vision model
-
-## Interaction Examples:
-- User: "Pick up the coffee mug"
- You: "I'll pick up the coffee mug for you." [Execute PickAndPlace with object_query="coffee mug"]
-
-- User: "Put the toy on the table"
- You: "I'll place the toy on the table." [Execute PickAndPlace with object_query="toy", target_query="on the table"]
-
-- User: "What do you see?"
-
-Remember: You're here to assist with manipulation tasks. Be helpful, precise, and always prioritize safe operation of the robot."""
-
-
-def main(): # type: ignore[no-untyped-def]
- """Main entry point."""
- print("\n" + "=" * 60)
- print("Piper Arm Robot with Claude Agent")
- print("=" * 60)
- print("\nThis system integrates:")
- print(" - Piper Arm 6-DOF robot")
- print(" - ZED stereo camera")
- print(" - Claude AI for natural language understanding")
- print(" - Qwen VLM for visual object detection")
- print(" - Web interface with text and voice input")
- print(" - Foxglove visualization via LCM")
- print("\nStarting system...\n")
-
- # Check for API key
- if not os.getenv("ANTHROPIC_API_KEY"):
- print("WARNING: ANTHROPIC_API_KEY not found in environment")
- print("Please set your API key in .env file or environment")
- sys.exit(1)
-
- logger.info("Starting Piper Arm Robot with Agent")
-
- # Create robot instance
- robot = PiperArmRobot() # type: ignore[abstract]
-
- try:
- # Start the robot (this is async, so we need asyncio.run)
- logger.info("Initializing robot...")
- asyncio.run(robot.start())
- logger.info("Robot initialized successfully")
-
- # Set up skill library
- skills = robot.get_skills() # type: ignore[no-untyped-call]
- skills.add(PickAndPlace)
- skills.add(KillSkill)
-
- # Create skill instances
- skills.create_instance("PickAndPlace", robot=robot)
- skills.create_instance("KillSkill", robot=robot, skill_library=skills)
-
- logger.info(f"Skills registered: {[skill.__name__ for skill in skills.get_class_skills()]}")
-
- # Set up streams for agent and web interface
- agent_response_subject = rx.subject.Subject() # type: ignore[var-annotated]
- agent_response_stream = agent_response_subject.pipe(ops.share())
- audio_subject = rx.subject.Subject() # type: ignore[var-annotated]
-
- # Set up streams for web interface
- streams = {} # type: ignore[var-annotated]
-
- text_streams = {
- "agent_responses": agent_response_stream,
- }
-
- # Create web interface first (needed for agent)
- try:
- web_interface = RobotWebInterface(
- port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams
- )
- logger.info("Web interface created successfully")
- except Exception as e:
- logger.error(f"Failed to create web interface: {e}")
- raise
-
- # Set up speech-to-text
- stt_node = stt() # type: ignore[no-untyped-call]
- stt_node.consume_audio(audio_subject.pipe(ops.share()))
-
- # Create Claude agent
- agent = ClaudeAgent(
- dev_name="piper_arm_agent",
- input_query_stream=web_interface.query_stream, # Use text input from web interface
- # input_query_stream=stt_node.emit_text(), # Uncomment to use voice input
- skills=skills,
- system_query=SYSTEM_PROMPT,
- model_name="claude-3-5-haiku-latest",
- thinking_budget_tokens=0,
- max_output_tokens_per_request=4096,
- )
-
- # Subscribe to agent responses
- agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x))
-
- # Set up text-to-speech for agent responses
- tts_node = tts() # type: ignore[no-untyped-call]
- tts_node.consume_text(agent.get_response_observable())
-
- logger.info("=" * 60)
- logger.info("Piper Arm Agent Ready!")
- logger.info("Web interface available at: http://localhost:5555")
- logger.info("Foxglove visualization available at: ws://localhost:8765")
- logger.info("You can:")
- logger.info(" - Type commands in the web interface")
- logger.info(" - Use voice commands")
- logger.info(" - Ask the robot to pick up objects")
- logger.info(" - Ask the robot to move objects to locations")
- logger.info("=" * 60)
-
- # Run web interface (this blocks)
- web_interface.run()
-
- except KeyboardInterrupt:
- logger.info("Keyboard interrupt received")
- except Exception as e:
- logger.error(f"Error running robot: {e}")
- import traceback
-
- traceback.print_exc()
- finally:
- logger.info("Shutting down...")
- # Stop the robot (this is also async)
- robot.stop()
- logger.info("Robot stopped")
-
-
-if __name__ == "__main__":
- main() # type: ignore[no-untyped-call]
diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py
index f989098f05..c241485472 100644
--- a/dimos/robot/all_blueprints.py
+++ b/dimos/robot/all_blueprints.py
@@ -19,8 +19,10 @@
"unitree-go2": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:nav",
"unitree-go2-basic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:basic",
"unitree-go2-nav": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:nav",
+ "unitree-go2-ros": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:ros",
"unitree-go2-detection": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:detection",
"unitree-go2-spatial": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:spatial",
+ "unitree-go2-temporal-memory": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:temporal_memory",
"unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic",
"unitree-go2-agentic-mcp": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_mcp",
"unitree-go2-agentic-ollama": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_ollama",
@@ -36,23 +38,21 @@
"unitree-g1-joystick": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:with_joystick",
"unitree-g1-full": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:full_featured",
"unitree-g1-detection": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:detection",
- # xArm manipulator blueprints
- "xarm-servo": "dimos.hardware.manipulators.xarm.xarm_blueprints:xarm_servo",
- "xarm5-servo": "dimos.hardware.manipulators.xarm.xarm_blueprints:xarm5_servo",
- "xarm7-servo": "dimos.hardware.manipulators.xarm.xarm_blueprints:xarm7_servo",
- "xarm-cartesian": "dimos.hardware.manipulators.xarm.xarm_blueprints:xarm_cartesian",
- "xarm-trajectory": "dimos.hardware.manipulators.xarm.xarm_blueprints:xarm_trajectory",
- # Piper manipulator blueprints
- "piper-servo": "dimos.hardware.manipulators.piper.piper_blueprints:piper_servo",
- "piper-cartesian": "dimos.hardware.manipulators.piper.piper_blueprints:piper_cartesian",
- "piper-trajectory": "dimos.hardware.manipulators.piper.piper_blueprints:piper_trajectory",
+ # Control orchestrator blueprints
+ "orchestrator-mock": "dimos.control.blueprints:orchestrator_mock",
+ "orchestrator-xarm7": "dimos.control.blueprints:orchestrator_xarm7",
+ "orchestrator-xarm6": "dimos.control.blueprints:orchestrator_xarm6",
+ "orchestrator-piper": "dimos.control.blueprints:orchestrator_piper",
+ "orchestrator-dual-mock": "dimos.control.blueprints:orchestrator_dual_mock",
+ "orchestrator-dual-xarm": "dimos.control.blueprints:orchestrator_dual_xarm",
+ "orchestrator-piper-xarm": "dimos.control.blueprints:orchestrator_piper_xarm",
# Demo blueprints
+ "demo-camera": "dimos.hardware.sensors.camera.module:demo_camera",
"demo-osm": "dimos.mapping.osm.demo_osm:demo_osm",
"demo-skill": "dimos.agents.skills.demo_skill:demo_skill",
"demo-gps-nav": "dimos.agents.skills.demo_gps_nav:demo_gps_nav_skill",
"demo-google-maps-skill": "dimos.agents.skills.demo_google_maps_skill:demo_google_maps_skill",
- "demo-remapping": "dimos.robot.unitree_webrtc.demo_remapping:remapping",
- "demo-remapping-transport": "dimos.robot.unitree_webrtc.demo_remapping:remapping_and_transport",
+ "demo-object-scene-registration": "dimos.perception.demo_object_scene_registration:demo_object_scene_registration",
"demo-error-on-name-conflicts": "dimos.robot.unitree_webrtc.demo_error_on_name_conflicts:blueprint",
}
@@ -83,10 +83,8 @@
"wavefront_frontier_explorer": "dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector",
"websocket_vis": "dimos.web.websocket_vis.websocket_vis_module",
"web_input": "dimos.agents.cli.web",
- # xArm manipulator modules
- "xarm_driver": "dimos.hardware.manipulators.xarm.xarm_driver",
- "cartesian_motion_controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller",
- "joint_trajectory_controller": "dimos.manipulation.control.trajectory_controller.joint_trajectory_controller",
+ # Control orchestrator module
+ "control_orchestrator": "dimos.control.orchestrator",
}
diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py
index 5cf09e02e3..a000502abc 100644
--- a/dimos/robot/cli/dimos.py
+++ b/dimos/robot/cli/dimos.py
@@ -184,7 +184,10 @@ def humancli(ctx: typer.Context) -> None:
@topic_app.command()
def echo(
topic: str = typer.Argument(..., help="Topic name to listen on (e.g., /goal_request)"),
- type_name: str = typer.Argument(..., help="Message type (e.g., PoseStamped)"),
+ type_name: str | None = typer.Argument(
+ None,
+ help="Optional message type (e.g., PoseStamped). If omitted, infer from '/topic#pkg.Msg'.",
+ ),
) -> None:
topic_echo(topic, type_name)
diff --git a/dimos/robot/cli/topic.py b/dimos/robot/cli/topic.py
index bdd1a29ae6..582099c4b6 100644
--- a/dimos/robot/cli/topic.py
+++ b/dimos/robot/cli/topic.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
+
import importlib
+import re
import time
import typer
from dimos.core.transport import LCMTransport, pLCMTransport
+from dimos.protocol.pubsub.lcmpubsub import LCMPubSubBase
_modules_to_try = [
"dimos.msgs.geometry_msgs",
@@ -42,24 +46,53 @@ def _resolve_type(type_name: str) -> type:
raise ValueError(f"Could not find type '{type_name}' in any known message modules")
-def topic_echo(topic: str, type_name: str) -> None:
- msg_type = _resolve_type(type_name)
- use_pickled = getattr(msg_type, "lcm_encode", None) is None
- transport: pLCMTransport[object] | LCMTransport[object] = (
- pLCMTransport(topic) if use_pickled else LCMTransport(topic, msg_type)
- )
-
- def _on_message(msg: object) -> None:
- print(msg)
+def topic_echo(topic: str, type_name: str | None) -> None:
+ # Explicit mode (legacy): unchanged.
+ if type_name is not None:
+ msg_type = _resolve_type(type_name)
+ use_pickled = getattr(msg_type, "lcm_encode", None) is None
+ transport: pLCMTransport[object] | LCMTransport[object] = (
+ pLCMTransport(topic) if use_pickled else LCMTransport(topic, msg_type)
+ )
- transport.subscribe(_on_message)
+ def _on_message(msg: object) -> None:
+ print(msg)
- typer.echo(f"Listening on {topic} for {type_name} messages... (Ctrl+C to stop)")
+ transport.subscribe(_on_message)
+ typer.echo(f"Listening on {topic} for {type_name} messages... (Ctrl+C to stop)")
+ try:
+ while True:
+ time.sleep(0.1)
+ except KeyboardInterrupt:
+ typer.echo("\nStopped.")
+ return
+
+ # Inferred typed mode: listen on /topic#pkg.Msg and decode from the msg_name suffix.
+ bus = LCMPubSubBase(autoconf=True)
+ bus.start() # starts threaded handle loop
+
+ typed_pattern = rf"^{re.escape(topic)}#.*"
+
+ def on_msg(channel: str, data: bytes) -> None:
+ _, msg_name = channel.split("#", 1) # e.g. "nav_msgs.Odometry"
+ pkg, cls_name = msg_name.split(".", 1) # "nav_msgs", "Odometry"
+ module = importlib.import_module(f"dimos.msgs.{pkg}")
+ cls = getattr(module, cls_name)
+ print(cls.lcm_decode(data))
+
+ assert bus.l is not None
+ bus.l.subscribe(typed_pattern, on_msg)
+
+ typer.echo(
+ f"Listening on {topic} (inferring from typed LCM channels like '{topic}#pkg.Msg')... "
+ "(Ctrl+C to stop)"
+ )
try:
while True:
time.sleep(0.1)
except KeyboardInterrupt:
+ bus.stop()
typer.echo("\nStopped.")
diff --git a/dimos/robot/test_all_blueprints.py b/dimos/robot/test_all_blueprints.py
new file mode 100644
index 0000000000..7e5fa6970c
--- /dev/null
+++ b/dimos/robot/test_all_blueprints.py
@@ -0,0 +1,44 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from dimos.core.blueprints import ModuleBlueprintSet
+from dimos.robot.all_blueprints import all_blueprints, get_blueprint_by_name
+
+# Optional dependencies that are allowed to be missing
+OPTIONAL_DEPENDENCIES = {"pyrealsense2", "geometry_msgs", "turbojpeg"}
+OPTIONAL_ERROR_SUBSTRINGS = {
+ "Unable to locate turbojpeg library automatically",
+}
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("blueprint_name", all_blueprints.keys())
+def test_all_blueprints_are_valid(blueprint_name: str) -> None:
+ """Test that all blueprints in all_blueprints are valid ModuleBlueprintSet instances."""
+ try:
+ blueprint = get_blueprint_by_name(blueprint_name)
+ except ModuleNotFoundError as e:
+ if e.name in OPTIONAL_DEPENDENCIES:
+ pytest.skip(f"Skipping due to missing optional dependency: {e.name}")
+ raise
+ except Exception as e:
+ message = str(e)
+ if any(substring in message for substring in OPTIONAL_ERROR_SUBSTRINGS):
+ pytest.skip(f"Skipping due to missing optional dependency: {message}")
+ raise
+ assert isinstance(blueprint, ModuleBlueprintSet), (
+ f"Blueprint '{blueprint_name}' is not a ModuleBlueprintSet, got {type(blueprint)}"
+ )
diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py
index bef0c0b127..8f4a138320 100644
--- a/dimos/robot/unitree/connection/connection.py
+++ b/dimos/robot/unitree/connection/connection.py
@@ -37,9 +37,9 @@
from dimos.core import rpc
from dimos.core.resource import Resource
from dimos.msgs.geometry_msgs import Pose, Transform, Twist
-from dimos.msgs.sensor_msgs import Image
+from dimos.msgs.sensor_msgs import Image, PointCloud2
from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ImageFormat
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.robot.unitree_webrtc.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar
from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg
from dimos.robot.unitree_webrtc.type.odometry import Odometry
from dimos.utils.decorators.decorators import simple_mcache
@@ -235,7 +235,7 @@ def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-de
return future.result()
@simple_mcache
- def raw_lidar_stream(self) -> Observable[LidarMessage]:
+ def raw_lidar_stream(self) -> Observable[RawLidarMsg]:
return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"]))
@simple_mcache
@@ -243,12 +243,8 @@ def raw_odom_stream(self) -> Observable[Pose]:
return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"]))
@simple_mcache
- def lidar_stream(self) -> Observable[LidarMessage]:
- return backpressure(
- self.raw_lidar_stream().pipe(
- ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) # type: ignore[arg-type]
- )
- )
+ def lidar_stream(self) -> Observable[PointCloud2]:
+ return backpressure(self.raw_lidar_stream().pipe(ops.map(pointcloud2_from_webrtc_lidar)))
@simple_mcache
def tf_stream(self) -> Observable[Transform]:
diff --git a/dimos/robot/unitree/connection/g1sim.py b/dimos/robot/unitree/connection/g1sim.py
index d72e7d17f6..cd4c3e4505 100644
--- a/dimos/robot/unitree/connection/g1sim.py
+++ b/dimos/robot/unitree/connection/g1sim.py
@@ -27,7 +27,7 @@
Twist,
Vector3,
)
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
from dimos.robot.unitree_webrtc.type.odometry import Odometry as SimOdometry
from dimos.utils.logging_config import setup_logger
@@ -39,7 +39,7 @@
class G1SimConnection(Module):
cmd_vel: In[Twist]
- lidar: Out[LidarMessage]
+ lidar: Out[PointCloud2]
odom: Out[PoseStamped]
ip: str | None
_global_config: GlobalConfig
diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py
index 34f81e2bbf..96a54117c8 100644
--- a/dimos/robot/unitree/connection/go2.py
+++ b/dimos/robot/unitree/connection/go2.py
@@ -13,7 +13,6 @@
# limitations under the License.
import logging
-from pathlib import Path
from threading import Thread
import time
from typing import Any, Protocol
@@ -35,17 +34,13 @@
Vector3,
)
from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
+from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ImageFormat
from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
from dimos.utils.data import get_data
from dimos.utils.decorators.decorators import simple_mcache
-from dimos.utils.logging_config import setup_logger
from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage
-logger = setup_logger(level=logging.INFO)
-
-# URDF path for Go2 robot
-_GO2_URDF = Path(__file__).parent.parent / "go2" / "go2.urdf"
+logger = logging.getLogger(__name__)
class Go2ConnectionProtocol(Protocol):
@@ -120,7 +115,22 @@ def odom_stream(self): # type: ignore[no-untyped-def]
# we don't have raw video stream in the data set
@simple_mcache
def video_stream(self): # type: ignore[no-untyped-def]
- video_store = TimedSensorReplay(f"{self.dir_name}/video") # type: ignore[var-annotated]
+ # Legacy Unitree recordings can have RGB bytes that were tagged/assumed as BGR.
+ # Fix at replay-time by coercing everything to RGB before publishing/logging.
+ def _autocast_video(x): # type: ignore[no-untyped-def]
+ # If the old recording tagged it as BGR, relabel to RGB (do NOT channel-swap again).
+ if isinstance(x, Image):
+ if x.format == ImageFormat.BGR:
+ x.format = ImageFormat.RGB
+ if not x.frame_id:
+ x.frame_id = "camera_optical"
+ return x
+
+ # Some recordings may store raw arrays or frame wrappers.
+ arr = x.to_ndarray(format="rgb24") if hasattr(x, "to_ndarray") else x
+ return Image.from_numpy(arr, format=ImageFormat.RGB, frame_id="camera_optical")
+
+ video_store = TimedSensorReplay(f"{self.dir_name}/video", autocast=_autocast_video) # type: ignore[var-annotated]
return video_store.stream(**self.replay_config) # type: ignore[arg-type]
def move(self, twist: Twist, duration: float = 0.0) -> bool:
@@ -135,7 +145,7 @@ class GO2Connection(Module, spec.Camera, spec.Pointcloud):
cmd_vel: In[Twist]
pointcloud: Out[PointCloud2]
odom: Out[PoseStamped]
- lidar: Out[LidarMessage]
+ lidar: Out[PointCloud2]
color_image: Out[Image]
camera_info: Out[CameraInfo]
@@ -196,13 +206,14 @@ def start(self) -> None:
self.connection.start()
- # Initialize Rerun world frame and load URDF (only if Rerun backend)
+ # Connect this worker process to Rerun if it will log sensor data.
if self._global_config.viewer_backend.startswith("rerun"):
- self._init_rerun_world()
+ connect_rerun(global_config=self._global_config)
def onimage(image: Image) -> None:
self.color_image.publish(image)
- rr.log("world/robot/camera/rgb", image.to_rerun())
+ if self._global_config.viewer_backend.startswith("rerun"):
+ rr.log("world/robot/camera/rgb", image.to_rerun())
self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish))
self._disposables.add(self.connection.odom_stream().subscribe(self._publish_tf))
@@ -218,45 +229,6 @@ def onimage(image: Image) -> None:
self.standup()
# self.record("go2_bigoffice")
- def _init_rerun_world(self) -> None:
- """Set up Rerun world frame, load URDF, and static assets.
-
- Does NOT compose blueprint - that's handled by ModuleBlueprintSet.build().
- """
- connect_rerun(global_config=self._global_config)
-
- # Set up world coordinate system AND register it as a named frame
- # This is KEY - it connects entity paths to the named frame system
- rr.log(
- "world",
- rr.ViewCoordinates.RIGHT_HAND_Z_UP,
- rr.CoordinateFrame("world"), # type: ignore[attr-defined]
- static=True,
- )
-
- # Bridge the named frame "world" to the implicit frame hierarchy "tf#/world"
- # This connects TF named frames to entity path hierarchy
- rr.log(
- "world",
- rr.Transform3D(
- parent_frame="world", # type: ignore[call-arg]
- child_frame="tf#/world", # type: ignore[call-arg]
- ),
- static=True,
- )
-
- # Load robot URDF
- if _GO2_URDF.exists():
- rr.log_file_from_path(
- str(_GO2_URDF),
- entity_path_prefix="world/robot",
- static=True,
- )
- logger.info(f"Loaded URDF from {_GO2_URDF}")
-
- # Log static camera pinhole (for frustum)
- rr.log("world/robot/camera", _camera_info_static().to_rerun(), static=True)
-
@rpc
def stop(self) -> None:
self.liedown()
@@ -299,42 +271,6 @@ def _publish_tf(self, msg: PoseStamped) -> None:
if self.odom.transport:
self.odom.publish(msg)
- # Log to Rerun: robot pose (relative to parent entity "world")
- rr.log(
- "world/robot",
- rr.Transform3D(
- translation=[msg.x, msg.y, msg.z],
- rotation=rr.Quaternion(
- xyzw=[
- msg.orientation.x,
- msg.orientation.y,
- msg.orientation.z,
- msg.orientation.w,
- ]
- ),
- ),
- )
- # Log axes as a child entity for visualization
- rr.log("world/robot/axes", rr.TransformAxes3D(0.5)) # type: ignore[attr-defined]
-
- # Log camera transform (compose base_link -> camera_link -> camera_optical)
- # transforms[1] is camera_link, transforms[2] is camera_optical
- cam_tf = transforms[1] + transforms[2] # compose transforms
- rr.log(
- "world/robot/camera",
- rr.Transform3D(
- translation=[cam_tf.translation.x, cam_tf.translation.y, cam_tf.translation.z],
- rotation=rr.Quaternion(
- xyzw=[
- cam_tf.rotation.x,
- cam_tf.rotation.y,
- cam_tf.rotation.z,
- cam_tf.rotation.w,
- ]
- ),
- ),
- )
-
def publish_camera_info(self) -> None:
while True:
self.camera_info.publish(_camera_info_static())
diff --git a/dimos/robot/unitree_webrtc/modular/detect.py b/dimos/robot/unitree_webrtc/modular/detect.py
index 2a266ef820..8f92d15e81 100644
--- a/dimos/robot/unitree_webrtc/modular/detect.py
+++ b/dimos/robot/unitree_webrtc/modular/detect.py
@@ -16,9 +16,9 @@
from dimos_lcm.sensor_msgs import CameraInfo
-from dimos.msgs.sensor_msgs import Image
+from dimos.msgs.sensor_msgs import Image, PointCloud2
from dimos.msgs.std_msgs import Header
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.robot.unitree_webrtc.type.lidar import pointcloud2_from_webrtc_lidar
from dimos.robot.unitree_webrtc.type.odometry import Odometry
image_resize_factor = 1
@@ -102,7 +102,7 @@ def transform_chain(odom_frame: Odometry) -> list: # type: ignore[type-arg]
def broadcast( # type: ignore[no-untyped-def]
timestamp: float,
- lidar_frame: LidarMessage,
+ lidar_frame: PointCloud2,
video_frame: Image,
odom_frame: Odometry,
detections,
@@ -115,7 +115,7 @@ def broadcast( # type: ignore[no-untyped-def]
from dimos.core import LCMTransport
from dimos.msgs.geometry_msgs import PoseStamped
- lidar_transport = LCMTransport("/lidar", LidarMessage) # type: ignore[var-annotated]
+ lidar_transport = LCMTransport("/lidar", PointCloud2) # type: ignore[var-annotated]
odom_transport = LCMTransport("/odom", PoseStamped) # type: ignore[var-annotated]
video_transport = LCMTransport("/image", Image) # type: ignore[var-annotated]
camera_info_transport = LCMTransport("/camera_info", CameraInfo) # type: ignore[var-annotated]
@@ -141,14 +141,15 @@ def process_data(): # type: ignore[no-untyped-def]
Detection2DModule,
build_imageannotations,
)
- from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
from dimos.robot.unitree_webrtc.type.odometry import Odometry
from dimos.utils.data import get_data
from dimos.utils.testing import TimedSensorReplay
get_data("unitree_office_walk")
target = 1751591272.9654856
- lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg)
+ lidar_store = TimedSensorReplay(
+ "unitree_office_walk/lidar", autocast=pointcloud2_from_webrtc_lidar
+ )
video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy)
odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg)
diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py
index 586f4d0ea7..d4a7736fc2 100644
--- a/dimos/robot/unitree_webrtc/mujoco_connection.py
+++ b/dimos/robot/unitree_webrtc/mujoco_connection.py
@@ -15,6 +15,7 @@
# limitations under the License.
+import atexit
import base64
from collections.abc import Callable
import functools
@@ -25,6 +26,7 @@
import threading
import time
from typing import Any, TypeVar
+import weakref
import numpy as np
from numpy.typing import NDArray
@@ -34,10 +36,16 @@
from dimos.core.global_config import GlobalConfig
from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3
-from dimos.msgs.sensor_msgs import Image
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import CameraInfo, Image, ImageFormat, PointCloud2
from dimos.robot.unitree_webrtc.type.odometry import Odometry
-from dimos.simulation.mujoco.constants import LAUNCHER_PATH, LIDAR_FPS, VIDEO_FPS
+from dimos.simulation.mujoco.constants import (
+ LAUNCHER_PATH,
+ LIDAR_FPS,
+ VIDEO_CAMERA_FOV,
+ VIDEO_FPS,
+ VIDEO_HEIGHT,
+ VIDEO_WIDTH,
+)
from dimos.simulation.mujoco.shared_memory import ShmWriter
from dimos.utils.data import get_data
from dimos.utils.logging_config import setup_logger
@@ -61,9 +69,11 @@ def __init__(self, global_config: GlobalConfig) -> None:
# Pre-download the mujoco_sim data.
get_data("mujoco_sim")
- # Trigger the download of the mujoco_menajerie package. This is so it
+ # Trigger the download of the mujoco_menagerie package. This is so it
# doesn't trigger in the mujoco process where it can time out.
- import mujoco_playground
+ from mujoco_playground._src import mjx_env
+
+ mjx_env.ensure_menagerie_exists()
self.global_config = global_config
self.process: subprocess.Popen[bytes] | None = None
@@ -77,6 +87,32 @@ def __init__(self, global_config: GlobalConfig) -> None:
self._stop_events: list[threading.Event] = []
self._is_cleaned_up = False
+ @staticmethod
+ def _compute_camera_info() -> CameraInfo:
+ """Compute camera intrinsics from MuJoCo camera parameters.
+
+ Uses pinhole camera model: f = height / (2 * tan(fovy / 2))
+ """
+ import math
+
+ fovy = math.radians(VIDEO_CAMERA_FOV)
+ f = VIDEO_HEIGHT / (2 * math.tan(fovy / 2))
+ cx = VIDEO_WIDTH / 2.0
+ cy = VIDEO_HEIGHT / 2.0
+
+ return CameraInfo(
+ frame_id="camera_optical",
+ height=VIDEO_HEIGHT,
+ width=VIDEO_WIDTH,
+ distortion_model="plumb_bob",
+ D=[0.0, 0.0, 0.0, 0.0, 0.0],
+ K=[f, 0.0, cx, 0.0, f, cy, 0.0, 0.0, 1.0],
+ R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
+ P=[f, 0.0, cx, 0.0, 0.0, f, cy, 0.0, 0.0, 0.0, 1.0, 0.0],
+ )
+
+ camera_info_static: CameraInfo = _compute_camera_info()
+
def start(self) -> None:
self.shm_data = ShmWriter()
@@ -87,6 +123,7 @@ def start(self) -> None:
try:
# mjpython must be used macOS (because of launch_passive inside mujoco_process.py)
executable = sys.executable if sys.platform != "darwin" else "mjpython"
+
self.process = subprocess.Popen(
[executable, str(LAUNCHER_PATH), config_pickle, shm_names_json],
)
@@ -106,6 +143,16 @@ def start(self) -> None:
raise RuntimeError(f"MuJoCo process failed to start (exit code {exit_code})")
if self.shm_data.is_ready():
logger.info("MuJoCo process started successfully")
+ # Register atexit handler to ensure subprocess is cleaned up
+ # Use weakref to avoid preventing garbage collection
+ weak_self = weakref.ref(self)
+
+ def cleanup_on_exit() -> None:
+ instance = weak_self()
+ if instance is not None:
+ instance.stop()
+
+ atexit.register(cleanup_on_exit)
return
time.sleep(0.1)
@@ -212,7 +259,7 @@ def get_odom_message(self) -> Odometry | None:
return None
- def get_lidar_message(self) -> LidarMessage | None:
+ def get_lidar_message(self) -> PointCloud2 | None:
if self.shm_data is None:
return None
@@ -261,7 +308,7 @@ def dispose() -> None:
return Observable(on_subscribe)
@functools.cache
- def lidar_stream(self) -> Observable[LidarMessage]:
+ def lidar_stream(self) -> Observable[PointCloud2]:
return self._create_stream(self.get_lidar_message, LIDAR_FPS, "Lidar")
@functools.cache
@@ -272,7 +319,8 @@ def odom_stream(self) -> Observable[Odometry]:
def video_stream(self) -> Observable[Image]:
def get_video_as_image() -> Image | None:
frame = self.get_video_frame()
- return Image.from_numpy(frame) if frame is not None else None
+ # MuJoCo renderer returns RGB uint8 frames; Image.from_numpy defaults to BGR.
+ return Image.from_numpy(frame, format=ImageFormat.RGB) if frame is not None else None
return self._create_stream(get_video_as_image, VIDEO_FPS, "Video")
diff --git a/dimos/robot/unitree_webrtc/testing/mock.py b/dimos/robot/unitree_webrtc/testing/mock.py
index 34ca390842..2af1754cb4 100644
--- a/dimos/robot/unitree_webrtc/testing/mock.py
+++ b/dimos/robot/unitree_webrtc/testing/mock.py
@@ -21,7 +21,8 @@
from reactivex import from_iterable, interval, operators as ops
from reactivex.observable import Observable
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage, RawLidarMsg
+from dimos.msgs.sensor_msgs import PointCloud2
+from dimos.robot.unitree_webrtc.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar
class Mock:
@@ -32,16 +33,16 @@ def __init__(self, root: str = "office", autocast: bool = True) -> None:
self.cnt = 0
@overload
- def load(self, name: int | str, /) -> LidarMessage: ...
+ def load(self, name: int | str, /) -> PointCloud2: ...
@overload
- def load(self, *names: int | str) -> list[LidarMessage]: ...
+ def load(self, *names: int | str) -> list[PointCloud2]: ...
- def load(self, *names: int | str) -> LidarMessage | list[LidarMessage]:
+ def load(self, *names: int | str) -> PointCloud2 | list[PointCloud2]:
if len(names) == 1:
return self.load_one(names[0])
return list(map(lambda name: self.load_one(name), names))
- def load_one(self, name: int | str) -> LidarMessage:
+ def load_one(self, name: int | str) -> PointCloud2:
if isinstance(name, int):
file_name = f"/lidar_data_{name:03d}.pickle"
else:
@@ -49,9 +50,9 @@ def load_one(self, name: int | str) -> LidarMessage:
full_path = self.root + file_name
with open(full_path, "rb") as f:
- return LidarMessage.from_msg(cast("RawLidarMsg", pickle.load(f)))
+ return pointcloud2_from_webrtc_lidar(cast("RawLidarMsg", pickle.load(f)))
- def iterate(self) -> Iterator[LidarMessage]:
+ def iterate(self) -> Iterator[PointCloud2]:
pattern = os.path.join(self.root, "lidar_data_*.pickle")
print("loading data", pattern)
for file_path in sorted(glob.glob(pattern)):
@@ -67,7 +68,7 @@ def stream(self, rate_hz: float = 10.0): # type: ignore[no-untyped-def]
ops.map(lambda x: x[0] if isinstance(x, tuple) else x),
)
- def save_stream(self, observable: Observable[LidarMessage]): # type: ignore[no-untyped-def]
+ def save_stream(self, observable: Observable[PointCloud2]): # type: ignore[no-untyped-def]
return observable.pipe(ops.map(lambda frame: self.save_one(frame))) # type: ignore[no-untyped-call]
def save(self, *frames): # type: ignore[no-untyped-def]
@@ -83,9 +84,8 @@ def save_one(self, frame): # type: ignore[no-untyped-def]
if os.path.isfile(full_path):
raise Exception(f"file {full_path} exists")
- if frame.__class__ == LidarMessage:
- frame = frame.raw_msg
-
+ # Note: This saves the PointCloud2 directly. For raw message saving,
+ # use the raw message before conversion.
with open(full_path, "wb") as f:
pickle.dump(frame, f)
diff --git a/dimos/robot/unitree_webrtc/testing/test_actors.py b/dimos/robot/unitree_webrtc/testing/test_actors.py
index 7e79ca24cc..def89346e8 100644
--- a/dimos/robot/unitree_webrtc/testing/test_actors.py
+++ b/dimos/robot/unitree_webrtc/testing/test_actors.py
@@ -19,7 +19,7 @@
from dimos import core
from dimos.core import Module, rpc
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
from dimos.robot.unitree_webrtc.type.map import Map as Mapper
@@ -95,7 +95,7 @@ def test_basic(dimos) -> None:
@pytest.mark.tool
def test_mapper_start(dimos) -> None:
mapper = dimos.deploy(Mapper)
- mapper.lidar.transport = core.LCMTransport("/lidar", LidarMessage)
+ mapper.lidar.transport = core.LCMTransport("/lidar", PointCloud2)
print("start res", mapper.start().result())
diff --git a/dimos/robot/unitree_webrtc/testing/test_mock.py b/dimos/robot/unitree_webrtc/testing/test_mock.py
deleted file mode 100644
index 0765894409..0000000000
--- a/dimos/robot/unitree_webrtc/testing/test_mock.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-
-import pytest
-
-from dimos.robot.unitree_webrtc.testing.mock import Mock
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
-
-
-@pytest.mark.needsdata
-def test_mock_load_cast() -> None:
- mock = Mock("test")
-
- # Load a frame with type casting
- frame = mock.load("a")
-
- # Verify it's a LidarMessage object
- assert frame.__class__.__name__ == "LidarMessage"
- assert hasattr(frame, "timestamp")
- assert hasattr(frame, "origin")
- assert hasattr(frame, "resolution")
- assert hasattr(frame, "pointcloud")
-
- # Verify pointcloud has points
- assert frame.pointcloud.has_points()
- assert len(frame.pointcloud.points) > 0
-
-
-@pytest.mark.needsdata
-def test_mock_iterate() -> None:
- """Test the iterate method of the Mock class."""
- mock = Mock("office")
-
- # Test iterate method
- frames = list(mock.iterate())
- assert len(frames) > 0
- for frame in frames:
- assert isinstance(frame, LidarMessage)
- assert frame.pointcloud.has_points()
-
-
-@pytest.mark.needsdata
-def test_mock_stream() -> None:
- frames = []
- sub1 = Mock("office").stream(rate_hz=30.0).subscribe(on_next=frames.append)
- time.sleep(0.1)
- sub1.dispose()
-
- assert len(frames) >= 2
- assert isinstance(frames[0], LidarMessage)
diff --git a/dimos/robot/unitree_webrtc/testing/test_tooling.py b/dimos/robot/unitree_webrtc/testing/test_tooling.py
index 456d600879..50b689931e 100644
--- a/dimos/robot/unitree_webrtc/testing/test_tooling.py
+++ b/dimos/robot/unitree_webrtc/testing/test_tooling.py
@@ -16,7 +16,7 @@
import pytest
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.robot.unitree_webrtc.type.lidar import pointcloud2_from_webrtc_lidar
from dimos.robot.unitree_webrtc.type.odometry import Odometry
from dimos.utils.reactive import backpressure
from dimos.utils.testing import TimedSensorReplay
@@ -24,7 +24,7 @@
@pytest.mark.tool
def test_replay_all() -> None:
- lidar_store = TimedSensorReplay("unitree/lidar", autocast=LidarMessage.from_msg)
+ lidar_store = TimedSensorReplay("unitree/lidar", autocast=pointcloud2_from_webrtc_lidar)
odom_store = TimedSensorReplay("unitree/odom", autocast=Odometry.from_msg)
video_store = TimedSensorReplay("unitree/video")
diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py
index b598373a09..df2909dc38 100644
--- a/dimos/robot/unitree_webrtc/type/lidar.py
+++ b/dimos/robot/unitree_webrtc/type/lidar.py
@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Unitree WebRTC lidar message parsing utilities."""
+
import time
from typing import TypedDict
import numpy as np
import open3d as o3d # type: ignore[import-untyped]
-from dimos.msgs.geometry_msgs import Vector3
from dimos.msgs.sensor_msgs import PointCloud2
-from dimos.types.timestamped import to_human_readable
+
+# Backwards compatibility alias for pickled data
+LidarMessage = PointCloud2
class RawLidarPoints(TypedDict):
@@ -40,92 +43,32 @@ class RawLidarData(TypedDict):
class RawLidarMsg(TypedDict):
- """Static type definition for raw LIDAR message"""
+ """Static type definition for raw LIDAR message from Unitree WebRTC."""
type: str
topic: str
data: RawLidarData
-class LidarMessage(PointCloud2):
- resolution: float # we lose resolution when encoding PointCloud2
- origin: Vector3
- raw_msg: RawLidarMsg | None
- # _costmap: Optional[Costmap] = None # TODO: Fix after costmap migration
-
- def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def]
- super().__init__(
- pointcloud=kwargs.get("pointcloud"),
- ts=kwargs.get("ts"),
- frame_id="world",
- )
-
- self.origin = kwargs.get("origin") # type: ignore[assignment]
- self.resolution = kwargs.get("resolution", 0.05)
-
- @classmethod
- def from_msg(cls: type["LidarMessage"], raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": # type: ignore[no-untyped-def]
- data = raw_message["data"]
- points = data["data"]["points"]
- pointcloud = o3d.geometry.PointCloud()
- pointcloud.points = o3d.utility.Vector3dVector(points)
-
- origin = Vector3(data["origin"])
- # webrtc decoding via native decompression doesn't require us
- # to shift the pointcloud by it's origin
- #
- # pointcloud.translate((origin / 2).to_tuple())
- cls_data = {
- "origin": origin,
- "resolution": data["resolution"],
- "pointcloud": pointcloud,
- # - this is broken in unitree webrtc api "stamp":1.758148e+09
- "ts": time.time(), # data["stamp"],
- "raw_msg": raw_message,
- **kwargs,
- }
- return cls(**cls_data)
-
- def __repr__(self) -> str:
- return f"LidarMessage(ts={to_human_readable(self.ts)}, origin={self.origin}, resolution={self.resolution}, {self.pointcloud})"
-
- def __iadd__(self, other: "LidarMessage") -> "LidarMessage": # type: ignore[override]
- self.pointcloud += other.pointcloud
- return self
-
- def __add__(self, other: "LidarMessage") -> "LidarMessage": # type: ignore[override]
- # Determine which message is more recent
- if self.ts >= other.ts:
- ts = self.ts
- origin = self.origin
- resolution = self.resolution
- else:
- ts = other.ts
- origin = other.origin
- resolution = other.resolution
-
- # Return a new LidarMessage with combined data
- return LidarMessage( # type: ignore[attr-defined, no-any-return]
- ts=ts,
- origin=origin,
- resolution=resolution,
- pointcloud=self.pointcloud + other.pointcloud,
- ).estimate_normals()
-
- @property
- def o3d_geometry(self): # type: ignore[no-untyped-def]
- return self.pointcloud
-
- # TODO: Fix after costmap migration
- # def costmap(self, voxel_size: float = 0.2) -> Costmap:
- # if not self._costmap:
- # down_sampled_pointcloud = self.pointcloud.voxel_down_sample(voxel_size=voxel_size)
- # inflate_radius_m = 1.0 * voxel_size if voxel_size > self.resolution else 0.0
- # grid, origin_xy = pointcloud_to_costmap(
- # down_sampled_pointcloud,
- # resolution=self.resolution,
- # inflate_radius_m=inflate_radius_m,
- # )
- # self._costmap = Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.resolution)
- #
- # return self._costmap
+def pointcloud2_from_webrtc_lidar(raw_message: RawLidarMsg, ts: float | None = None) -> PointCloud2:
+ """Convert a raw Unitree WebRTC lidar message to PointCloud2.
+
+ Args:
+ raw_message: Raw lidar message from Unitree WebRTC API
+ ts: Optional timestamp override. If None, uses current time.
+
+ Returns:
+ PointCloud2 message with the lidar points
+ """
+ data = raw_message["data"]
+ points = data["data"]["points"]
+
+ pointcloud = o3d.geometry.PointCloud()
+ pointcloud.points = o3d.utility.Vector3dVector(points)
+
+ return PointCloud2(
+ pointcloud=pointcloud,
+ # webrtc stamp is broken (e.g., "stamp": 1.758148e+09), use current time
+ ts=ts if ts is not None else time.time(),
+ frame_id="world",
+ )
diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py
index 3bc1e61aef..f9abd96b88 100644
--- a/dimos/robot/unitree_webrtc/type/map.py
+++ b/dimos/robot/unitree_webrtc/type/map.py
@@ -28,12 +28,11 @@
from dimos.msgs.nav_msgs import OccupancyGrid
from dimos.msgs.sensor_msgs import PointCloud2
from dimos.robot.unitree.connection.go2 import Go2ConnectionProtocol
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
class Map(Module):
- lidar: In[LidarMessage]
- global_map: Out[LidarMessage]
+ lidar: In[PointCloud2]
+ global_map: Out[PointCloud2]
global_costmap: Out[OccupancyGrid]
_point_cloud_accumulator: PointCloudAccumulator
@@ -85,17 +84,9 @@ def to_PointCloud2(self) -> PointCloud2:
ts=time.time(),
)
- def to_lidar_message(self) -> LidarMessage:
- return LidarMessage(
- pointcloud=self._point_cloud_accumulator.get_point_cloud(),
- origin=[0.0, 0.0, 0.0],
- resolution=self.voxel_size,
- ts=time.time(),
- )
-
# TODO: Why is this RPC?
@rpc
- def add_frame(self, frame: LidarMessage) -> None:
+ def add_frame(self, frame: PointCloud2) -> None:
self._point_cloud_accumulator.add(frame.pointcloud)
@property
@@ -103,10 +94,10 @@ def o3d_geometry(self) -> o3d.geometry.PointCloud:
return self._point_cloud_accumulator.get_point_cloud()
def _publish(self, _: Any) -> None:
- self.global_map.publish(self.to_lidar_message())
+ self.global_map.publish(self.to_PointCloud2())
occupancygrid = general_occupancy(
- self.to_lidar_message(),
+ self.to_PointCloud2(),
resolution=self.cost_resolution,
min_height=self.min_height,
max_height=self.max_height,
@@ -127,7 +118,7 @@ def _publish(self, _: Any) -> None:
def deploy(dimos: DimosCluster, connection: Go2ConnectionProtocol): # type: ignore[no-untyped-def]
mapper = dimos.deploy(Map, global_publish_interval=1.0) # type: ignore[attr-defined]
- mapper.global_map.transport = LCMTransport("/global_map", LidarMessage)
+ mapper.global_map.transport = LCMTransport("/global_map", PointCloud2)
mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid)
mapper.lidar.connect(connection.pointcloud) # type: ignore[attr-defined]
mapper.start()
diff --git a/dimos/robot/unitree_webrtc/type/test_lidar.py b/dimos/robot/unitree_webrtc/type/test_lidar.py
index 0ad918409b..7543fe63a7 100644
--- a/dimos/robot/unitree_webrtc/type/test_lidar.py
+++ b/dimos/robot/unitree_webrtc/type/test_lidar.py
@@ -15,7 +15,8 @@
import itertools
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
+from dimos.robot.unitree_webrtc.type.lidar import pointcloud2_from_webrtc_lidar
from dimos.utils.testing import SensorReplay
@@ -24,5 +25,5 @@ def test_init() -> None:
for raw_frame in itertools.islice(lidar.iterate(), 5):
assert isinstance(raw_frame, dict)
- frame = LidarMessage.from_msg(raw_frame)
- assert isinstance(frame, LidarMessage)
+ frame = pointcloud2_from_webrtc_lidar(raw_frame)
+ assert isinstance(frame, PointCloud2)
diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py
deleted file mode 100644
index 2f8afbc743..0000000000
--- a/dimos/robot/unitree_webrtc/type/test_map.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import pytest
-
-from dimos.mapping.pointclouds.accumulators.general import _splice_cylinder
-from dimos.robot.unitree_webrtc.testing.helpers import show3d
-from dimos.robot.unitree_webrtc.testing.mock import Mock
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
-from dimos.robot.unitree_webrtc.type.map import Map
-from dimos.utils.testing import SensorReplay
-
-
-@pytest.mark.vis
-def test_costmap_vis() -> None:
- map = Map()
- map.start()
- mock = Mock("office")
- frames = list(mock.iterate())
-
- for frame in frames:
- print(frame)
- map.add_frame(frame)
-
- # Get global map and costmap
- global_map = map.to_lidar_message()
- print(f"Global map has {len(global_map.pointcloud.points)} points")
- show3d(global_map.pointcloud, title="Global Map").run()
-
-
-@pytest.mark.vis
-def test_reconstruction_with_realtime_vis() -> None:
- map = Map()
- map.start()
- mock = Mock("office")
-
- # Process frames and visualize final map
- for frame in mock.iterate():
- map.add_frame(frame)
-
- show3d(map.o3d_geometry, title="Reconstructed Map").run()
-
-
-@pytest.mark.vis
-def test_splice_vis() -> None:
- mock = Mock("test")
- target = mock.load("a")
- insert = mock.load("b")
- show3d(_splice_cylinder(target.pointcloud, insert.pointcloud, shrink=0.7)).run()
-
-
-@pytest.mark.vis
-def test_robot_vis() -> None:
- map = Map()
- map.start()
- mock = Mock("office")
-
- # Process all frames
- for frame in mock.iterate():
- map.add_frame(frame)
-
- show3d(map.o3d_geometry, title="global dynamic map test").run()
-
-
-@pytest.fixture
-def map_():
- map = Map(voxel_size=0.5)
- yield map
- map.stop()
-
-
-def test_robot_mapping(map_) -> None:
- lidar_replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg)
-
- # Mock the output streams to avoid publishing errors
- class MockStream:
- def publish(self, msg) -> None:
- pass # Do nothing
-
- map_.global_costmap = MockStream()
- map_.global_map = MockStream()
-
- # Process all frames from replay
- for frame in lidar_replay.iterate():
- map_.add_frame(frame)
-
- # Check the built map
- global_map = map_.to_lidar_message()
- pointcloud = global_map.pointcloud
-
- # Verify map has points
- assert len(pointcloud.points) > 0
- print(f"Map contains {len(pointcloud.points)} points")
diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py
index 7629644ed6..b683b76559 100644
--- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py
+++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py
@@ -14,27 +14,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from pathlib import Path
import platform
-from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped]
- ImageAnnotations,
+from dimos_lcm.foxglove_msgs.ImageAnnotations import (
+ ImageAnnotations, # type: ignore[import-untyped]
)
+from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate # type: ignore[import-untyped]
from dimos.agents.agent import llm_agent
from dimos.agents.cli.human import human_input
from dimos.agents.cli.web import web_input
from dimos.agents.ollama_agent import ollama_installed
from dimos.agents.skills.navigation import navigation_skill
+from dimos.agents.skills.person_follow import person_follow_skill
from dimos.agents.skills.speak_skill import speak_skill
from dimos.agents.spec import Provider
from dimos.agents.vlm_agent import vlm_agent
from dimos.agents.vlm_stream_tester import vlm_stream_tester
from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE
from dimos.core.blueprints import autoconnect
-from dimos.core.transport import JpegLcmTransport, JpegShmTransport, LCMTransport, pSHMTransport
+from dimos.core.transport import (
+ JpegLcmTransport,
+ JpegShmTransport,
+ LCMTransport,
+ ROSTransport,
+ pSHMTransport,
+)
from dimos.dashboard.tf_rerun_module import tf_rerun
from dimos.mapping.costmapper import cost_mapper
from dimos.mapping.voxels import voxel_mapper
+from dimos.msgs.geometry_msgs import PoseStamped
from dimos.msgs.sensor_msgs import Image, PointCloud2
from dimos.msgs.vision_msgs import Detection2DArray
from dimos.navigation.frontier_exploration import (
@@ -43,15 +53,19 @@
from dimos.navigation.replanning_a_star.module import (
replanning_a_star_planner,
)
-from dimos.perception.detection.moduleDB import ObjectDBModule, detectionDB_module
+from dimos.perception.detection.module3D import Detection3DModule, detection3d_module
+from dimos.perception.experimental.temporal_memory import temporal_memory
from dimos.perception.spatial_perception import spatial_memory
from dimos.protocol.mcp.mcp import MCPModule
from dimos.robot.foxglove_bridge import foxglove_bridge
+import dimos.robot.unitree.connection.go2 as _go2_mod
from dimos.robot.unitree.connection.go2 import GO2Connection, go2_connection
from dimos.robot.unitree_webrtc.unitree_skill_container import unitree_skills
from dimos.utils.monitoring import utilization
from dimos.web.websocket_vis.websocket_vis_module import websocket_vis
+_GO2_URDF = Path(_go2_mod.__file__).parent.parent / "go2" / "go2.urdf"
+
# Mac has some issue with high bandwidth UDP
#
# so we use pSHMTransport for color_image
@@ -77,7 +91,12 @@
go2_connection(),
linux if platform.system() == "Linux" else mac,
websocket_vis(),
- tf_rerun(), # Auto-visualize all TF transforms in Rerun
+ tf_rerun(
+ urdf_path=str(_GO2_URDF),
+ cameras=[
+ ("world/robot/camera", "camera_optical", GO2Connection.camera_info_static),
+ ],
+ ),
).global_config(n_dask_workers=4, robot_model="unitree_go2")
nav = autoconnect(
@@ -88,42 +107,51 @@
wavefront_frontier_explorer(),
).global_config(n_dask_workers=6, robot_model="unitree_go2")
+ros = nav.transports(
+ {
+ ("lidar", PointCloud2): ROSTransport("lidar", PointCloud2),
+ ("global_map", PointCloud2): ROSTransport("global_map", PointCloud2),
+ ("odom", PoseStamped): ROSTransport("odom", PoseStamped),
+ ("color_image", Image): ROSTransport("color_image", Image),
+ }
+)
+
detection = (
autoconnect(
nav,
- detectionDB_module(
+ detection3d_module(
camera_info=GO2Connection.camera_info_static,
),
)
.remappings(
[
- (ObjectDBModule, "pointcloud", "global_map"),
+ (Detection3DModule, "pointcloud", "global_map"),
]
)
.transports(
{
# Detection 3D module outputs
- ("detections", ObjectDBModule): LCMTransport(
+ ("detections", Detection3DModule): LCMTransport(
"/detector3d/detections", Detection2DArray
),
- ("annotations", ObjectDBModule): LCMTransport(
+ ("annotations", Detection3DModule): LCMTransport(
"/detector3d/annotations", ImageAnnotations
),
- # ("scene_update", ObjectDBModule): LCMTransport(
- # "/detector3d/scene_update", SceneUpdate
- # ),
- ("detected_pointcloud_0", ObjectDBModule): LCMTransport(
+ ("scene_update", Detection3DModule): LCMTransport(
+ "/detector3d/scene_update", SceneUpdate
+ ),
+ ("detected_pointcloud_0", Detection3DModule): LCMTransport(
"/detector3d/pointcloud/0", PointCloud2
),
- ("detected_pointcloud_1", ObjectDBModule): LCMTransport(
+ ("detected_pointcloud_1", Detection3DModule): LCMTransport(
"/detector3d/pointcloud/1", PointCloud2
),
- ("detected_pointcloud_2", ObjectDBModule): LCMTransport(
+ ("detected_pointcloud_2", Detection3DModule): LCMTransport(
"/detector3d/pointcloud/2", PointCloud2
),
- ("detected_image_0", ObjectDBModule): LCMTransport("/detector3d/image/0", Image),
- ("detected_image_1", ObjectDBModule): LCMTransport("/detector3d/image/1", Image),
- ("detected_image_2", ObjectDBModule): LCMTransport("/detector3d/image/2", Image),
+ ("detected_image_0", Detection3DModule): LCMTransport("/detector3d/image/0", Image),
+ ("detected_image_1", Detection3DModule): LCMTransport("/detector3d/image/1", Image),
+ ("detected_image_2", Detection3DModule): LCMTransport("/detector3d/image/2", Image),
}
)
)
@@ -157,6 +185,7 @@
_common_agentic = autoconnect(
human_input(),
navigation_skill(),
+ person_follow_skill(camera_info=GO2Connection.camera_info_static),
unitree_skills(),
web_input(),
speak_skill(),
@@ -198,3 +227,8 @@
vlm_agent(),
vlm_stream_tester(),
)
+
+temporal_memory = autoconnect(
+ agentic,
+ temporal_memory(),
+)
diff --git a/dimos/simulation/mujoco/constants.py b/dimos/simulation/mujoco/constants.py
index aca916a372..4e35011530 100644
--- a/dimos/simulation/mujoco/constants.py
+++ b/dimos/simulation/mujoco/constants.py
@@ -17,6 +17,7 @@
# Video/Camera constants
VIDEO_WIDTH = 320
VIDEO_HEIGHT = 240
+VIDEO_CAMERA_FOV = 45 # MuJoCo default FOV for head_camera (degrees)
DEPTH_CAMERA_FOV = 160
# Depth camera range/filtering constants
diff --git a/dimos/simulation/mujoco/model.py b/dimos/simulation/mujoco/model.py
index de533521da..bc309b7307 100644
--- a/dimos/simulation/mujoco/model.py
+++ b/dimos/simulation/mujoco/model.py
@@ -37,14 +37,21 @@ def _get_data_dir() -> epath.Path:
def get_assets() -> dict[str, bytes]:
data_dir = _get_data_dir()
+ assets: dict[str, bytes] = {}
+
# Assets used from https://sketchfab.com/3d-models/mersus-office-8714be387bcd406898b2615f7dae3a47
# Created by Ryan Cassidy and Coleman Costello
- assets: dict[str, bytes] = {}
mjx_env.update_assets(assets, data_dir, "*.xml")
mjx_env.update_assets(assets, data_dir / "scene_office1/textures", "*.png")
mjx_env.update_assets(assets, data_dir / "scene_office1/office_split", "*.obj")
mjx_env.update_assets(assets, mjx_env.MENAGERIE_PATH / "unitree_go1" / "assets")
mjx_env.update_assets(assets, mjx_env.MENAGERIE_PATH / "unitree_g1" / "assets")
+
+ # From: https://sketchfab.com/3d-models/jeong-seun-34-42956ca979404a038b8e0d3e496160fd
+ person_dir = epath.Path(str(get_data("person")))
+ mjx_env.update_assets(assets, person_dir, "*.obj")
+ mjx_env.update_assets(assets, person_dir, "*.png")
+
return assets
@@ -106,9 +113,38 @@ def get_model_xml(robot: str, scene_xml: str) -> str:
map_elem.set("znear", "0.01")
map_elem.set("zfar", "10000")
+ _add_person_object(root)
+
return ET.tostring(root, encoding="unicode")
+def _add_person_object(root: ET.Element) -> None:
+ asset = root.find("asset")
+
+ if asset is None:
+ asset = ET.SubElement(root, "asset")
+
+ ET.SubElement(asset, "mesh", name="person_mesh", file="jeong_seun_34.obj")
+ ET.SubElement(asset, "texture", name="person_texture", file="material_0.png", type="2d")
+ ET.SubElement(asset, "material", name="person_material", texture="person_texture")
+
+ worldbody = root.find("worldbody")
+
+ if worldbody is None:
+ worldbody = ET.SubElement(root, "worldbody")
+
+ person_body = ET.SubElement(worldbody, "body", name="person", pos="0 0 0", mocap="true")
+
+ ET.SubElement(
+ person_body,
+ "geom",
+ type="mesh",
+ mesh="person_mesh",
+ material="person_material",
+ euler="1.5708 0 0",
+ )
+
+
def load_scene_xml(config: GlobalConfig) -> str:
if config.mujoco_room_from_occupancy:
path = Path(config.mujoco_room_from_occupancy)
diff --git a/dimos/simulation/mujoco/mujoco_process.py b/dimos/simulation/mujoco/mujoco_process.py
index 2363a8abd3..f3e6eba279 100755
--- a/dimos/simulation/mujoco/mujoco_process.py
+++ b/dimos/simulation/mujoco/mujoco_process.py
@@ -29,8 +29,7 @@
import open3d as o3d # type: ignore[import-untyped]
from dimos.core.global_config import GlobalConfig
-from dimos.msgs.geometry_msgs import Vector3
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
from dimos.simulation.mujoco.constants import (
DEPTH_CAMERA_FOV,
LIDAR_FPS,
@@ -41,6 +40,7 @@
)
from dimos.simulation.mujoco.depth_camera import depth_image_to_point_cloud
from dimos.simulation.mujoco.model import load_model, load_scene_xml
+from dimos.simulation.mujoco.person_on_track import PersonPositionController
from dimos.simulation.mujoco.shared_memory import ShmReader
from dimos.utils.logging_config import setup_logger
@@ -97,6 +97,9 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None:
camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera")
lidar_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_front_camera")
+
+ person_position_controller = PersonPositionController(model)
+
lidar_left_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_left_camera")
lidar_right_camera_id = mujoco.mj_name2id(
model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_right_camera"
@@ -138,6 +141,8 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None:
for _ in range(config.mujoco_steps_per_frame):
mujoco.mj_step(model, data)
+ person_position_controller.tick(data)
+
m_viewer.sync()
# Always update odometry
@@ -205,11 +210,10 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None:
pcd.points = o3d.utility.Vector3dVector(combined_points)
pcd = pcd.voxel_down_sample(voxel_size=LIDAR_RESOLUTION)
- lidar_msg = LidarMessage(
+ lidar_msg = PointCloud2(
pointcloud=pcd,
ts=time.time(),
- origin=Vector3(pos[0], pos[1], pos[2]),
- resolution=LIDAR_RESOLUTION,
+ frame_id="world",
)
shm.write_lidar(lidar_msg)
@@ -220,6 +224,8 @@ def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None:
if time_until_next_step > 0:
time.sleep(time_until_next_step)
+ person_position_controller.stop()
+
if __name__ == "__main__":
diff --git a/dimos/simulation/mujoco/person_on_track.py b/dimos/simulation/mujoco/person_on_track.py
new file mode 100644
index 0000000000..a816b5f3ee
--- /dev/null
+++ b/dimos/simulation/mujoco/person_on_track.py
@@ -0,0 +1,160 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import mujoco
+import numpy as np
+from numpy.typing import NDArray
+
+from dimos.core.transport import LCMTransport
+from dimos.msgs.geometry_msgs import Pose
+
+
+class PersonPositionController:
+ """Controls the person position in MuJoCo by subscribing to LCM pose updates."""
+
+ def __init__(self, model: mujoco.MjModel) -> None:
+ person_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, "person")
+ self._person_mocap_id = model.body_mocapid[person_body_id]
+ self._latest_pose: Pose | None = None
+ self._transport: LCMTransport[Pose] = LCMTransport("/person_pose", Pose)
+ self._transport.subscribe(self._on_pose)
+
+ def _on_pose(self, pose: Pose) -> None:
+ self._latest_pose = pose
+
+ def tick(self, data: mujoco.MjData) -> None:
+ if self._latest_pose is None:
+ return
+
+ pose = self._latest_pose
+ data.mocap_pos[self._person_mocap_id][0] = pose.position.x
+ data.mocap_pos[self._person_mocap_id][1] = pose.position.y
+ data.mocap_pos[self._person_mocap_id][2] = pose.position.z
+ data.mocap_quat[self._person_mocap_id] = [
+ pose.orientation.w,
+ pose.orientation.x,
+ pose.orientation.y,
+ pose.orientation.z,
+ ]
+
+ def stop(self) -> None:
+ self._transport.stop()
+
+
+class PersonTrackPublisher:
+ """Publishes person poses along a track via LCM."""
+
+ def __init__(self, track: list[tuple[float, float]]) -> None:
+ self._speed = 0.004
+ self._waypoint_threshold = 0.1
+ self._rotation_radius = 1.0
+ self._track = track
+ self._current_waypoint_idx = 0
+ self._initialized = False
+ self._current_pos = np.array([0.0, 0.0])
+ self._transport: LCMTransport[Pose] = LCMTransport("/person_pose", Pose)
+
+ def _get_segment_heading(self, from_idx: int, to_idx: int) -> float:
+ """Get heading angle for traveling from one waypoint to another."""
+ from_wp = np.array(self._track[from_idx])
+ to_wp = np.array(self._track[to_idx])
+ direction = to_wp - from_wp
+ return float(np.arctan2(direction[1], direction[0]))
+
+ def _lerp_angle(self, a1: float, a2: float, t: float) -> float:
+ """Interpolate between two angles, handling wrapping."""
+ diff = a2 - a1
+ while diff > np.pi:
+ diff -= 2 * np.pi
+ while diff < -np.pi:
+ diff += 2 * np.pi
+ return a1 + diff * t
+
+ def tick(self) -> None:
+ if not self._initialized:
+ first_point = self._track[0]
+ self._current_pos = np.array([first_point[0], first_point[1]])
+ self._current_waypoint_idx = 1
+ heading = self._get_segment_heading(0, 1)
+ self._publish_pose(self._current_pos, heading)
+ self._initialized = True
+ return
+
+ n = len(self._track)
+
+ prev_idx = (self._current_waypoint_idx - 1) % n
+ curr_idx = self._current_waypoint_idx
+ next_idx = (self._current_waypoint_idx + 1) % n
+ prev_prev_idx = (prev_idx - 1) % n
+
+ prev_wp = np.array(self._track[prev_idx])
+ curr_wp = np.array(self._track[curr_idx])
+
+ to_target = curr_wp - self._current_pos
+ distance_to_curr = float(np.linalg.norm(to_target))
+ distance_from_prev = float(np.linalg.norm(self._current_pos - prev_wp))
+
+ # Headings for current turn (at curr_wp)
+ incoming_heading = self._get_segment_heading(prev_idx, curr_idx)
+ outgoing_heading = self._get_segment_heading(curr_idx, next_idx)
+
+ # Headings for previous turn (at prev_wp)
+ prev_incoming_heading = self._get_segment_heading(prev_prev_idx, prev_idx)
+ prev_outgoing_heading = incoming_heading
+
+ # Determine heading based on position in rotation zones
+ in_leaving_zone = distance_from_prev < self._rotation_radius
+ in_approaching_zone = distance_to_curr < self._rotation_radius
+
+ if in_leaving_zone and in_approaching_zone:
+ # Overlap - prioritize approaching zone
+ t = 0.5 * (1.0 - distance_to_curr / self._rotation_radius)
+ heading = self._lerp_angle(incoming_heading, outgoing_heading, t)
+ elif in_leaving_zone:
+ # Finishing turn after passing prev_wp (t goes from 0.5 to 1.0)
+ t = 0.5 + 0.5 * (distance_from_prev / self._rotation_radius)
+ heading = self._lerp_angle(prev_incoming_heading, prev_outgoing_heading, t)
+ elif in_approaching_zone:
+ # Starting turn before reaching curr_wp (t goes from 0.0 to 0.5)
+ t = 0.5 * (1.0 - distance_to_curr / self._rotation_radius)
+ heading = self._lerp_angle(incoming_heading, outgoing_heading, t)
+ else:
+ # Between zones, use segment heading
+ heading = incoming_heading
+
+ # Move toward target
+ if distance_to_curr > 0:
+ dir_norm = to_target / distance_to_curr
+ self._current_pos[0] += dir_norm[0] * self._speed
+ self._current_pos[1] += dir_norm[1] * self._speed
+
+ # Check if reached waypoint
+ if distance_to_curr < self._waypoint_threshold:
+ self._current_waypoint_idx = next_idx
+
+ # Publish pose
+ self._publish_pose(self._current_pos, heading + np.pi)
+
+ def _publish_pose(self, pos: NDArray[np.floating[Any]], heading: float) -> None:
+ c, s = np.cos(heading / 2), np.sin(heading / 2)
+ pose = Pose(
+ position=[pos[0], pos[1], 0.0],
+ orientation=[0.0, 0.0, s, c], # x, y, z, w
+ )
+ self._transport.broadcast(None, pose)
+
+ def stop(self) -> None:
+ self._transport.stop()
diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py
index 00491b4379..212c7ac60a 100644
--- a/dimos/simulation/mujoco/policy.py
+++ b/dimos/simulation/mujoco/policy.py
@@ -20,9 +20,12 @@
import mujoco
import numpy as np
-import onnxruntime as rt # type: ignore[import-untyped]
+import onnxruntime as ort # type: ignore[import-untyped]
from dimos.simulation.mujoco.input_controller import InputController
+from dimos.utils.logging_config import setup_logger
+
+logger = setup_logger()
class OnnxController(ABC):
@@ -37,7 +40,8 @@ def __init__(
drift_compensation: list[float] | None = None,
) -> None:
self._output_names = ["continuous_actions"]
- self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"])
+ self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers())
+ logger.info(f"Loaded policy: {policy_path} with providers: {self._policy.get_providers()}")
self._action_scale = action_scale
self._default_angles = default_angles
diff --git a/dimos/simulation/mujoco/shared_memory.py b/dimos/simulation/mujoco/shared_memory.py
index 4c22062233..70ba50af2b 100644
--- a/dimos/simulation/mujoco/shared_memory.py
+++ b/dimos/simulation/mujoco/shared_memory.py
@@ -21,7 +21,7 @@
import numpy as np
from numpy.typing import NDArray
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
from dimos.simulation.mujoco.constants import VIDEO_HEIGHT, VIDEO_WIDTH
from dimos.utils.logging_config import setup_logger
@@ -148,7 +148,7 @@ def write_odom(self, pos: NDArray[Any], quat: NDArray[Any], timestamp: float) ->
odom_array[7] = timestamp
self._increment_seq(2)
- def write_lidar(self, lidar_msg: LidarMessage) -> None:
+ def write_lidar(self, lidar_msg: PointCloud2) -> None:
data = pickle.dumps(lidar_msg)
data_len = len(data)
@@ -242,7 +242,7 @@ def write_command(self, linear: NDArray[Any], angular: NDArray[Any]) -> None:
cmd_array[3:6] = angular
self._increment_seq(3)
- def read_lidar(self) -> tuple[LidarMessage | None, int]:
+ def read_lidar(self) -> tuple[PointCloud2 | None, int]:
seq = self._get_seq(4)
if seq > 0:
# Read length
diff --git a/dimos/stream/video_operators.py b/dimos/stream/video_operators.py
index 558972e155..548bba7598 100644
--- a/dimos/stream/video_operators.py
+++ b/dimos/stream/video_operators.py
@@ -16,7 +16,7 @@
from collections.abc import Callable
from datetime import datetime, timedelta
from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
import cv2
import numpy as np
diff --git a/dimos/types/test_weaklist.py b/dimos/types/test_weaklist.py
index 990cc0d164..06f9f851ce 100644
--- a/dimos/types/test_weaklist.py
+++ b/dimos/types/test_weaklist.py
@@ -54,6 +54,7 @@ def test_weaklist_basic_operations() -> None:
assert SampleObject(4) not in wl
+@pytest.mark.integration
def test_weaklist_auto_removal() -> None:
"""Test that objects are automatically removed when garbage collected."""
wl = WeakList()
@@ -136,6 +137,7 @@ def test_weaklist_clear() -> None:
assert obj1 not in wl
+@pytest.mark.integration
def test_weaklist_iteration_during_modification() -> None:
"""Test that iteration works even if objects are deleted during iteration."""
wl = WeakList()
diff --git a/dimos/utils/cli/boxglove/boxglove.py b/dimos/utils/cli/boxglove/boxglove.py
deleted file mode 100644
index 3ace1c1aaa..0000000000
--- a/dimos/utils/cli/boxglove/boxglove.py
+++ /dev/null
@@ -1,292 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import numpy as np
-import reactivex.operators as ops
-from rich.text import Text
-from textual.app import App, ComposeResult
-from textual.containers import Container
-from textual.reactive import reactive
-from textual.widgets import Footer, Static
-
-from dimos import core
-from dimos.msgs.nav_msgs import OccupancyGrid
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
-
-if TYPE_CHECKING:
- from reactivex.disposable import Disposable
-
- from dimos.msgs.nav_msgs import OccupancyGrid
- from dimos.utils.cli.boxglove.connection import Connection
-
-
-blocks = "█▗▖▝▘"
-shades = "█░░░░"
-crosses = "┼┌┐└┘"
-quadrant = "█▟▙▜▛"
-triangles = "◼◢◣◥◤" # 45-degree triangular blocks
-
-
-alphabet = crosses
-
-# Box drawing characters for smooth edges
-top_left = alphabet[1] # Quadrant lower right
-top_right = alphabet[2] # Quadrant lower left
-bottom_left = alphabet[3] # Quadrant upper right
-bottom_right = alphabet[4] # Quadrant upper left
-full = alphabet[0] # Full block
-
-
-class OccupancyGridApp(App): # type: ignore[type-arg]
- """A Textual app for visualizing OccupancyGrid data in real-time."""
-
- CSS = """
- Screen {
- layout: vertical;
- overflow: hidden;
- }
-
- #grid-container {
- width: 100%;
- height: 1fr;
- overflow: hidden;
- margin: 0;
- padding: 0;
- }
-
- #grid-display {
- width: 100%;
- height: 100%;
- margin: 0;
- padding: 0;
- }
-
- Footer {
- dock: bottom;
- height: 1;
- }
- """
-
- # Reactive properties
- grid_data: reactive[OccupancyGrid | None] = reactive(None)
-
- BINDINGS = [
- ("q", "quit", "Quit"),
- ("ctrl+c", "quit", "Quit"),
- ]
-
- def __init__(self, connection: Connection, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
- super().__init__(*args, **kwargs)
- self.connection = connection
- self.subscription: Disposable | None = None
- self.grid_display: Static | None = None
- self.cached_grid: OccupancyGrid | None = None
-
- def compose(self) -> ComposeResult:
- """Create the app layout."""
- # Container for the grid (no scrolling since we scale to fit)
- with Container(id="grid-container"):
- self.grid_display = Static("", id="grid-display")
- yield self.grid_display
-
- yield Footer()
-
- def on_mount(self) -> None:
- """Subscribe to the connection when the app starts."""
- self.theme = "flexoki"
-
- # Subscribe to the OccupancyGrid stream
- def on_grid(grid: OccupancyGrid) -> None:
- self.grid_data = grid
-
- def on_error(error: Exception) -> None:
- self.notify(f"Error: {error}", severity="error")
-
- self.subscription = self.connection().subscribe(on_next=on_grid, on_error=on_error) # type: ignore[assignment]
-
- async def on_unmount(self) -> None:
- """Clean up subscription when app closes."""
- if self.subscription:
- self.subscription.dispose()
-
- def watch_grid_data(self, grid: OccupancyGrid | None) -> None:
- """Update display when new grid data arrives."""
- if grid is None:
- return
-
- # Cache the grid for rerendering on terminal resize
- self.cached_grid = grid
-
- # Render the grid as ASCII art
- grid_text = self.render_grid(grid)
- self.grid_display.update(grid_text) # type: ignore[union-attr]
-
- def on_resize(self, event) -> None: # type: ignore[no-untyped-def]
- """Handle terminal resize events."""
- if self.cached_grid:
- # Re-render with new terminal dimensions
- grid_text = self.render_grid(self.cached_grid)
- self.grid_display.update(grid_text) # type: ignore[union-attr]
-
- def render_grid(self, grid: OccupancyGrid) -> Text:
- """Render the OccupancyGrid as colored ASCII art, scaled to fit terminal."""
- text = Text()
-
- # Get the actual container dimensions
- container = self.query_one("#grid-container")
- content_width = container.content_size.width
- content_height = container.content_size.height
-
- # Each cell will be 2 chars wide to make square pixels
- terminal_width = max(1, content_width // 2)
- terminal_height = max(1, content_height)
-
- # Handle edge cases
- if grid.width == 0 or grid.height == 0:
- return text # Return empty text for empty grid
-
- # Calculate scaling factors (as floats for smoother scaling)
- scale_x = grid.width / terminal_width
- scale_y = grid.height / terminal_height
-
- # Use the larger scale to ensure the grid fits
- scale_float = max(1.0, max(scale_x, scale_y))
-
- # For smoother resizing, we'll use fractional scaling
- # This means we might sample between grid cells
- render_width = min(int(grid.width / scale_float), terminal_width)
- render_height = min(int(grid.height / scale_float), terminal_height)
-
- # Store both integer and float scale for different uses
- int(np.ceil(scale_float)) # For legacy compatibility
-
- # Adjust render dimensions to use all available space
- # This reduces jumping by allowing fractional cell sizes
- actual_scale_x = grid.width / render_width if render_width > 0 else 1
- actual_scale_y = grid.height / render_height if render_height > 0 else 1
-
- # Function to get value with fractional scaling
- def get_cell_value(grid_data: np.ndarray, x: int, y: int) -> int: # type: ignore[type-arg]
- # Use fractional coordinates for smoother scaling
- y_center = int((y + 0.5) * actual_scale_y)
- x_center = int((x + 0.5) * actual_scale_x)
-
- # Clamp to grid bounds
- y_center = max(0, min(y_center, grid.height - 1))
- x_center = max(0, min(x_center, grid.width - 1))
-
- # For now, just sample the center point
- # Could do area averaging for smoother results
- return grid_data[y_center, x_center] # type: ignore[no-any-return]
-
- # Helper function to check if a cell is an obstacle
- def is_obstacle(grid_data: np.ndarray, x: int, y: int) -> bool: # type: ignore[type-arg]
- if x < 0 or x >= render_width or y < 0 or y >= render_height:
- return False
- value = get_cell_value(grid_data, x, y)
- return value > 90 # Consider cells with >90% probability as obstacles
-
- # Character and color mapping with intelligent obstacle rendering
- def get_cell_char_and_style(grid_data: np.ndarray, x: int, y: int) -> tuple[str, str]: # type: ignore[type-arg]
- value = get_cell_value(grid_data, x, y)
- norm_value = min(value, 100) / 100.0
-
- if norm_value > 0.9:
- # Check neighbors for intelligent character selection
- top = is_obstacle(grid_data, x, y + 1)
- bottom = is_obstacle(grid_data, x, y - 1)
- left = is_obstacle(grid_data, x - 1, y)
- right = is_obstacle(grid_data, x + 1, y)
-
- # Count neighbors
- neighbor_count = sum([top, bottom, left, right])
-
- # Select character based on neighbor configuration
- if neighbor_count == 4:
- # All neighbors are obstacles - use full block
- symbol = full + full
- elif neighbor_count == 3:
- # Three neighbors - use full block (interior edge)
- symbol = full + full
- elif neighbor_count == 2:
- # Two neighbors - check configuration
- if top and bottom:
- symbol = full + full # Vertical corridor
- elif left and right:
- symbol = full + full # Horizontal corridor
- elif top and left:
- symbol = bottom_right + " "
- elif top and right:
- symbol = " " + bottom_left
- elif bottom and left:
- symbol = top_right + " "
- elif bottom and right:
- symbol = " " + top_left
- else:
- symbol = full + full
- elif neighbor_count == 1:
- # One neighbor - point towards it
- if top:
- symbol = bottom_left + bottom_right
- elif bottom:
- symbol = top_left + top_right
- elif left:
- symbol = top_right + bottom_right
- elif right:
- symbol = top_left + bottom_left
- else:
- symbol = full + full
- else:
- # No neighbors - isolated obstacle
- symbol = full + full
-
- return symbol, None # type: ignore[return-value]
- else:
- return " ", None # type: ignore[return-value]
-
- # Render the scaled grid row by row (flip Y axis for proper display)
- for y in range(render_height - 1, -1, -1):
- for x in range(render_width):
- char, style = get_cell_char_and_style(grid.grid, x, y)
- text.append(char, style=style)
- if y > 0: # Add newline except for last row
- text.append("\n")
-
- # Could show scale info in footer status if needed
-
- return text
-
-
-def main() -> None:
- """Run the OccupancyGrid visualizer with a connection."""
- # app = OccupancyGridApp(core.LCMTransport("/global_costmap", OccupancyGrid).observable)
-
- app = OccupancyGridApp(
- lambda: core.LCMTransport("/lidar", LidarMessage) # type: ignore[no-untyped-call]
- .observable()
- .pipe(ops.map(lambda msg: msg.costmap())) # type: ignore[attr-defined]
- )
- app.run()
- import time
-
- while True:
- time.sleep(1)
-
-
-if __name__ == "__main__":
- main()
diff --git a/dimos/utils/cli/boxglove/connection.py b/dimos/utils/cli/boxglove/connection.py
deleted file mode 100644
index 1743684626..0000000000
--- a/dimos/utils/cli/boxglove/connection.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2025-2026 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections.abc import Callable
-import pickle
-
-import reactivex as rx
-from reactivex import operators as ops
-from reactivex.disposable import Disposable
-from reactivex.observable import Observable
-
-from dimos.msgs.nav_msgs import OccupancyGrid
-from dimos.msgs.sensor_msgs import PointCloud2
-from dimos.protocol.pubsub import lcm # type: ignore[attr-defined]
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
-from dimos.robot.unitree_webrtc.type.map import Map
-from dimos.utils.data import get_data
-from dimos.utils.reactive import backpressure
-from dimos.utils.testing import TimedSensorReplay
-
-Connection = Callable[[], Observable[OccupancyGrid]]
-
-
-def live_connection() -> Observable[OccupancyGrid]:
- def subscribe(observer, scheduler=None): # type: ignore[no-untyped-def]
- lcm.autoconf()
- l = lcm.LCM()
-
- def on_message(grid: OccupancyGrid, _) -> None: # type: ignore[no-untyped-def]
- observer.on_next(grid)
-
- l.subscribe(lcm.Topic("/global_costmap", OccupancyGrid), on_message)
- l.start()
-
- def dispose() -> None:
- l.stop()
-
- return Disposable(dispose)
-
- return rx.create(subscribe)
-
-
-def recorded_connection() -> Observable[OccupancyGrid]:
- lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg)
- mapper = Map()
- return backpressure(
- lidar_store.stream(speed=1).pipe(
- ops.map(mapper.add_frame),
- ops.map(lambda _: mapper.costmap().inflate(0.1).gradient()), # type: ignore[attr-defined]
- )
- )
-
-
-def single_message() -> Observable[OccupancyGrid]:
- pointcloud_pickle = get_data("lcm_msgs") / "sensor_msgs/PointCloud2.pickle"
- with open(pointcloud_pickle, "rb") as f:
- pointcloud = PointCloud2.lcm_decode(pickle.load(f))
- mapper = Map()
- mapper.add_frame(pointcloud)
- return rx.just(mapper.costmap()) # type: ignore[attr-defined]
diff --git a/dimos/utils/cli/human/humancli.py b/dimos/utils/cli/human/humancli.py
index a0ce0afff4..cd7ca91637 100644
--- a/dimos/utils/cli/human/humancli.py
+++ b/dimos/utils/cli/human/humancli.py
@@ -15,6 +15,7 @@
from __future__ import annotations
from datetime import datetime
+import json
import textwrap
import threading
from typing import TYPE_CHECKING
@@ -140,7 +141,9 @@ def receive_msg(msg) -> None: # type: ignore[no-untyped-def]
)
elif isinstance(msg, AIMessage):
content = msg.content or ""
- tool_calls = msg.additional_kwargs.get("tool_calls", [])
+ tool_calls = getattr(msg, "tool_calls", None) or msg.additional_kwargs.get(
+ "tool_calls", []
+ )
# Display the main content first
if content:
@@ -174,9 +177,10 @@ def receive_msg(msg) -> None: # type: ignore[no-untyped-def]
def _format_tool_call(self, tool_call: ToolCall) -> str:
"""Format a tool call for display."""
- f = tool_call.get("function", {})
- name = f.get("name", "unknown") # type: ignore[attr-defined]
- return f"▶ {name}({f.get('arguments', '')})" # type: ignore[attr-defined]
+ name = tool_call.get("name", "unknown")
+ args = tool_call.get("args", {})
+ args_str = json.dumps(args, separators=(",", ":"))
+ return f"▶ {name}({args_str})"
def _add_message(self, timestamp: str, sender: str, content: str, color: str) -> None:
"""Add a message to the chat log."""
diff --git a/dimos/utils/test_reactive.py b/dimos/utils/test_reactive.py
index a0f3fe42ef..5bfc0a590f 100644
--- a/dimos/utils/test_reactive.py
+++ b/dimos/utils/test_reactive.py
@@ -82,6 +82,7 @@ def _dispose() -> None:
return proxy
+@pytest.mark.integration
def test_backpressure_handling() -> None:
# Create a dedicated scheduler for this test to avoid thread leaks
test_scheduler = ThreadPoolScheduler(max_workers=8)
@@ -141,6 +142,7 @@ def test_backpressure_handling() -> None:
test_scheduler.executor.shutdown(wait=True)
+@pytest.mark.integration
def test_getter_streaming_blocking() -> None:
source = dispose_spy(
rx.interval(0.2).pipe(ops.map(lambda i: np.array([i, i + 1, i + 2])), ops.take(50))
@@ -175,6 +177,7 @@ def test_getter_streaming_blocking_timeout() -> None:
assert source.is_disposed()
+@pytest.mark.integration
def test_getter_streaming_nonblocking() -> None:
source = dispose_spy(rx.interval(0.2).pipe(ops.take(50)))
diff --git a/dimos/utils/testing/replay.py b/dimos/utils/testing/replay.py
index e9b69b6ecd..89225c322e 100644
--- a/dimos/utils/testing/replay.py
+++ b/dimos/utils/testing/replay.py
@@ -40,7 +40,7 @@ class SensorReplay(Generic[T]):
Args:
name: The name of the test dataset
autocast: Optional function that takes unpickled data and returns a processed result.
- For example: lambda data: LidarMessage.from_msg(data)
+ For example: pointcloud2_from_webrtc_lidar
"""
def __init__(self, name: str, autocast: Callable[[Any], T] | None = None) -> None:
diff --git a/dimos/utils/testing/test_replay.py b/dimos/utils/testing/test_replay.py
index 44b6a232c8..640fe92979 100644
--- a/dimos/utils/testing/test_replay.py
+++ b/dimos/utils/testing/test_replay.py
@@ -16,7 +16,8 @@
from reactivex import operators as ops
-from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
+from dimos.msgs.sensor_msgs import PointCloud2
+from dimos.robot.unitree_webrtc.type.lidar import pointcloud2_from_webrtc_lidar
from dimos.robot.unitree_webrtc.type.odometry import Odometry
from dimos.utils.data import get_data
from dimos.utils.testing import replay
@@ -33,10 +34,10 @@ def test_sensor_replay() -> None:
def test_sensor_replay_cast() -> None:
counter = 0
for message in replay.SensorReplay(
- name="office_lidar", autocast=LidarMessage.from_msg
+ name="office_lidar", autocast=pointcloud2_from_webrtc_lidar
).iterate():
counter += 1
- assert isinstance(message, LidarMessage)
+ assert isinstance(message, PointCloud2)
assert counter == 500
@@ -204,7 +205,7 @@ def test_first_methods() -> None:
"""Test first() and first_timestamp() methods"""
# Test SensorReplay.first()
- lidar_replay = replay.SensorReplay("office_lidar", autocast=LidarMessage.from_msg)
+ lidar_replay = replay.SensorReplay("office_lidar", autocast=pointcloud2_from_webrtc_lidar)
print("first file", lidar_replay.files[0])
# Verify the first file ends with 000.pickle using regex
@@ -214,13 +215,13 @@ def test_first_methods() -> None:
first_msg = lidar_replay.first()
assert first_msg is not None
- assert isinstance(first_msg, LidarMessage)
+ assert isinstance(first_msg, PointCloud2)
# Verify it's the same type as first item from iterate()
first_from_iterate = next(lidar_replay.iterate())
print("DONE")
assert type(first_msg) is type(first_from_iterate)
- # Since LidarMessage.from_msg uses time.time(), timestamps will be slightly different
+ # Since pointcloud2_from_webrtc_lidar uses time.time(), timestamps will be slightly different
assert abs(first_msg.ts - first_from_iterate.ts) < 1.0 # Within 1 second tolerance
# Test TimedSensorReplay.first_timestamp()
diff --git a/dimos/web/templates/rerun_dashboard.html b/dimos/web/templates/rerun_dashboard.html
index 9917d9d2af..f0792079e3 100644
--- a/dimos/web/templates/rerun_dashboard.html
+++ b/dimos/web/templates/rerun_dashboard.html
@@ -6,15 +6,78 @@
* { margin: 0; padding: 0; box-sizing: border-box; }
html, body { width: 100%; height: 100%; overflow: hidden; }
body { background: #0a0a0f; font-family: -apple-system, system-ui, sans-serif; }
+ :root { --command-center-width: max(30vw, 35rem); }
.container { display: flex; width: 100%; height: 100%; }
- .rerun { flex: 1; border: none; }
- .command-center { width: 400px; border: none; border-left: 1px solid #333; }
+ .command-center {
+ width: var(--command-center-width);
+ min-width: 16rem;
+ border: none;
+ border-right: 1px solid #333;
+ }
+ .rerun { flex: 1 1 auto; border: none; min-width: 0; }
+ .divider {
+ width: 6px;
+ background: linear-gradient(180deg, #202530 0%, #141824 100%);
+ cursor: col-resize;
+ border-left: 1px solid #0f1016;
+ border-right: 1px solid #0f1016;
+ }
+ .divider:hover { background: #2a3140; }
+ .divider.dragging { background: #3a4458; }
+ body.dragging { user-select: none; cursor: col-resize; }
+