Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dimos/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __getattr__(self, name: str):

if name in self.rpcs:
return lambda *args, **kwargs: self.rpc.call_sync(
f"{self.remote_name}/{name}", (args, kwargs), timeout=2.0
f"{self.remote_name}/{name}", (args, kwargs)
)

# return super().__getattr__(name)
Expand Down
177 changes: 177 additions & 0 deletions dimos/core/test_rpcstress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time

from dimos.core import In, Module, Out, rpc


class Counter(Module):
current_count: int = 0

count_stream: Out[int] = None

def __init__(self):
super().__init__()
self.current_count = 0

@rpc
def increment(self):
"""Increment the counter and publish the new value."""
self.current_count += 1
self.count_stream.publish(self.current_count)
return self.current_count


class CounterValidator(Module):
"""Calls counter.increment() as fast as possible and validates no numbers are skipped."""

count_in: In[int] = None

def __init__(self, increment_func):
super().__init__()
self.increment_func = increment_func
self.last_seen = 0
self.missing_numbers = []
self.running = False
self.call_thread = None
self.call_count = 0
self.total_latency = 0.0
self.call_start_time = None
self.waiting_for_response = False

@rpc
def start(self):
"""Start the validator."""
self.count_in.subscribe(self._on_count_received)
self.running = True
self.call_thread = threading.Thread(target=self._call_loop)
self.call_thread.start()

@rpc
def stop(self):
"""Stop the validator."""
self.running = False
if self.call_thread:
self.call_thread.join()

def _on_count_received(self, count: int):
"""Check if we received all numbers in sequence and trigger next call."""
# Calculate round trip time
if self.call_start_time:
latency = time.time() - self.call_start_time
self.total_latency += latency

if count != self.last_seen + 1:
for missing in range(self.last_seen + 1, count):
self.missing_numbers.append(missing)
print(f"[VALIDATOR] Missing number detected: {missing}")
self.last_seen = count

# Signal that we can make the next call
self.waiting_for_response = False

def _call_loop(self):
"""Call increment only after receiving response from previous call."""
while self.running:
if not self.waiting_for_response:
try:
self.waiting_for_response = True
self.call_start_time = time.time()
result = self.increment_func()
call_time = time.time() - self.call_start_time
self.call_count += 1
if self.call_count % 100 == 0:
avg_latency = (
self.total_latency / self.call_count if self.call_count > 0 else 0
)
print(
f"[VALIDATOR] Made {self.call_count} calls, last result: {result}, RPC call time: {call_time * 1000:.2f}ms, avg RTT: {avg_latency * 1000:.2f}ms"
)
except Exception as e:
print(f"[VALIDATOR] Error calling increment: {e}")
self.waiting_for_response = False
time.sleep(0.001) # Small delay on error
else:
# Don't sleep - busy wait for maximum speed
pass

@rpc
def get_stats(self):
"""Get validation statistics."""
avg_latency = self.total_latency / self.call_count if self.call_count > 0 else 0
return {
"call_count": self.call_count,
"last_seen": self.last_seen,
"missing_count": len(self.missing_numbers),
"missing_numbers": self.missing_numbers[:10] if self.missing_numbers else [],
"avg_rtt_ms": avg_latency * 1000,
"calls_per_sec": self.call_count / 10.0 if self.call_count > 0 else 0,
}


if __name__ == "__main__":
import dimos.core as core
from dimos.core import pLCMTransport

# Start dimos with 2 workers
client = core.start(2)

# Deploy counter module
counter = client.deploy(Counter)
counter.count_stream.transport = pLCMTransport("/counter_stream")

# Deploy validator module with increment function
validator = client.deploy(CounterValidator, counter.increment)
validator.count_in.transport = pLCMTransport("/counter_stream")

# Connect validator to counter's output
validator.count_in.connect(counter.count_stream)

# Start modules
validator.start()

print("[MAIN] Counter and validator started. Running for 10 seconds...")

# Test direct RPC speed for comparison
print("\n[MAIN] Testing direct RPC call speed for 1 second...")
start = time.time()
direct_count = 0
while time.time() - start < 1.0:
counter.increment()
direct_count += 1
print(f"[MAIN] Direct RPC calls per second: {direct_count}")

# Run for 10 seconds
time.sleep(10)

# Get stats before stopping
stats = validator.get_stats()
print(f"\n[MAIN] Final statistics:")
print(f" - Total calls made: {stats['call_count']}")
print(f" - Last number seen: {stats['last_seen']}")
print(f" - Missing numbers: {stats['missing_count']}")
print(f" - Average RTT: {stats['avg_rtt_ms']:.2f}ms")
print(f" - Calls per second: {stats['calls_per_sec']:.1f}")
if stats["missing_numbers"]:
print(f" - First missing numbers: {stats['missing_numbers']}")

# Stop modules
validator.stop()

# Shutdown dimos
client.shutdown()

print("[MAIN] Test complete.")
11 changes: 11 additions & 0 deletions dimos/msgs/sensor_msgs/Image.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ImageFormat(Enum):
GRAY = "GRAY" # 8-bit Grayscale
GRAY16 = "GRAY16" # 16-bit Grayscale
DEPTH = "DEPTH" # 32-bit Float Depth
DEPTH16 = "DEPTH16" # 16-bit Integer Depth (millimeters)


@dataclass
Expand Down Expand Up @@ -169,6 +170,8 @@ def to_opencv(self) -> np.ndarray:
return self.data
elif self.format == ImageFormat.DEPTH:
return self.data # Depth images are already in the correct format
elif self.format == ImageFormat.DEPTH16:
return self.data # 16-bit depth images are already in the correct format
else:
raise ValueError(f"Unsupported format conversion: {self.format}")

Expand Down Expand Up @@ -373,6 +376,11 @@ def _get_lcm_encoding(self) -> str:
return "32FC1"
elif self.dtype == np.float64:
return "64FC1"
elif self.format == ImageFormat.DEPTH16:
if self.dtype == np.uint16:
return "16UC1" # 16-bit unsigned depth
elif self.dtype == np.int16:
return "16SC1" # 16-bit signed depth

raise ValueError(
f"Cannot determine LCM encoding for format={self.format}, dtype={self.dtype}"
Expand All @@ -393,6 +401,9 @@ def _parse_encoding(encoding: str) -> dict:
"32FC1": {"format": ImageFormat.DEPTH, "dtype": np.float32, "channels": 1},
"32FC3": {"format": ImageFormat.RGB, "dtype": np.float32, "channels": 3},
"64FC1": {"format": ImageFormat.DEPTH, "dtype": np.float64, "channels": 1},
# 16-bit depth encodings
"16UC1": {"format": ImageFormat.DEPTH16, "dtype": np.uint16, "channels": 1},
"16SC1": {"format": ImageFormat.DEPTH16, "dtype": np.int16, "channels": 1},
}

if encoding not in encoding_map:
Expand Down
30 changes: 21 additions & 9 deletions dimos/navigation/bt_navigator/navigator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,28 @@ def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> Optional[PoseStamp
return goal

try:
transform = self.tf.get(
parent_frame=odom_frame,
child_frame=goal.frame_id,
time_point=goal.ts,
time_tolerance=1.0,
)
transform = None
max_retries = 3

for attempt in range(max_retries):
transform = self.tf.get(
parent_frame=odom_frame,
child_frame=goal.frame_id,
)

if transform:
break

if not transform:
logger.error(f"Could not find transform from '{goal.frame_id}' to '{odom_frame}'")
return None
if attempt < max_retries - 1:
logger.warning(
f"Transform attempt {attempt + 1}/{max_retries} failed, retrying..."
)
time.sleep(1.0)
else:
logger.error(
f"Could not find transform from '{goal.frame_id}' to '{odom_frame}' after {max_retries} attempts"
)
return None

pose = apply_transform(goal, transform)
transformed_goal = PoseStamped(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
lookahead_distance: float = 5.0,
max_explored_distance: float = 10.0,
info_gain_threshold: float = 0.03,
num_no_gain_attempts: int = 4,
num_no_gain_attempts: int = 2,
goal_timeout: float = 15.0,
**kwargs,
):
Expand Down Expand Up @@ -639,7 +639,8 @@ def get_exploration_goal(
logger.info(
f"No information gain for {self.no_gain_counter} consecutive attempts"
)
self.reset_exploration_session()
self.no_gain_counter = 0 # Reset counter when stopping due to no gain
self.stop_exploration()
return None
else:
self.no_gain_counter = 0
Expand Down Expand Up @@ -724,6 +725,7 @@ def stop_exploration(self) -> bool:
return False

self.exploration_active = False
self.no_gain_counter = 0 # Reset counter when exploration stops
self.stop_event.set()

if self.exploration_thread and self.exploration_thread.is_alive():
Expand Down
2 changes: 1 addition & 1 deletion dimos/navigation/global_planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def plan(self, goal: Pose) -> Optional[Path]:

# Get current position from odometry
robot_pos = self.latest_odom.position
costmap = self.latest_costmap.inflate(0.1).gradient(max_distance=1.0)
costmap = self.latest_costmap.inflate(0.2).gradient(max_distance=1.5)

# Run A* planning
path = astar(costmap, goal.position, robot_pos)
Expand Down
Loading