Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 44 additions & 24 deletions dimos/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dask.distributed import Client, LocalCluster
from rich.console import Console

import signal
import dimos.core.colors as colors
from dimos.core.core import rpc
from dimos.core.module import Module, ModuleBase, ModuleConfig
Expand Down Expand Up @@ -49,6 +50,46 @@
]


class RpcCall:
def __init__(self, original_method, rpc, name, remote_name, unsub_fns, stop_client=None):
self._original_method = original_method
self._rpc = rpc
self._name = name
self._remote_name = remote_name
self._unsub_fns = unsub_fns
self._stop_client = stop_client

if original_method:
self.__doc__ = original_method.__doc__
self.__name__ = original_method.__name__
self.__qualname__ = f"{self.__class__.__name__}.{original_method.__name__}"

def __call__(self, *args, **kwargs):
if not self._rpc:
return None

# TODO: This may not be needed anymore.
#
# For stop/close/shutdown, use call_nowait to avoid deadlock
# (the remote side stops its RPC service before responding)
if self._name == "stop":
if self._rpc:
self._rpc.call_nowait(f"{self._remote_name}/{self._name}", (args, kwargs))
if self._stop_client:
self._stop_client()
return None

result, unsub_fn = self._rpc.call_sync(f"{self._remote_name}/{self._name}", (args, kwargs))
self._unsub_fns.append(unsub_fn)
return result

def __getstate__(self):
return (self._original_method, self._name, self._remote_name, self._unsub_fns)

def __setstate__(self, state):
self._original_method, self._name, self._remote_name, self._unsub_fns = state


class CudaCleanupPlugin:
"""Dask worker plugin to cleanup CUDA resources on shutdown."""

Expand Down Expand Up @@ -125,29 +166,10 @@ def __getattr__(self, name: str):
raise AttributeError(f"{name} is not found.")

if name in self.rpcs:
# Get the original method to preserve its docstring
original_method = getattr(self.actor_class, name, None)

def rpc_call(*args, **kwargs):
# For stop/close/shutdown, use call_nowait to avoid deadlock
# (the remote side stops its RPC service before responding)
if name in ("stop", "close", "shutdown"):
if self.rpc:
self.rpc.call_nowait(f"{self.remote_name}/{name}", (args, kwargs))
self.stop_client()
return None

result, unsub_fn = self.rpc.call_sync(f"{self.remote_name}/{name}", (args, kwargs))
self._unsub_fns.append(unsub_fn)
return result

# Copy docstring and other attributes from original method
if original_method:
rpc_call.__doc__ = original_method.__doc__
rpc_call.__name__ = original_method.__name__
rpc_call.__qualname__ = f"{self.__class__.__name__}.{original_method.__name__}"

return rpc_call
return RpcCall(
original_method, self.rpc, name, self.remote_name, self._unsub_fns, self.stop_client
)

# return super().__getattr__(name)
# Try to avoid recursion by directly accessing attributes that are known
Expand Down Expand Up @@ -309,8 +331,6 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client:
n: Number of workers (defaults to CPU count)
memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default)
"""
import signal
import atexit

console = Console()
if not n:
Expand Down
154 changes: 154 additions & 0 deletions dimos/core/blueprints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from collections import defaultdict
from functools import cached_property
from types import MappingProxyType
from typing import Any, Mapping, get_origin, get_args

from dimos.core.dimos import Dimos
from dimos.core.module import Module
from dimos.core.stream import In, Out
from dimos.core.transport import LCMTransport, pLCMTransport
from dimos.utils.generic import short_id


@dataclass(frozen=True)
class ModuleBlueprint:
module: type[Module]
incoming: dict[str, type]
outgoing: dict[str, type]
args: tuple[Any]
kwargs: dict[str, Any]


@dataclass(frozen=True)
class ModuleBlueprintSet:
blueprints: tuple[ModuleBlueprint, ...]
# TODO: Replace Any
transports: Mapping[tuple[str, type], Any] = field(default_factory=lambda: MappingProxyType({}))

def with_transports(self, transports: dict[tuple[str, type], Any]) -> "ModuleBlueprintSet":
return ModuleBlueprintSet(
blueprints=self.blueprints,
transports=MappingProxyType({**self.transports, **transports}),
)

def _get_transport_for(self, name: str, type: type) -> Any:
transport = self.transports.get((name, type), None)
if transport:
return transport

use_pickled = "lcm_encode" not in type.__dict__
topic = f"/{name}" if self._is_name_unique(name) else f"/{short_id()}"
transport = pLCMTransport(topic) if use_pickled else LCMTransport(topic, type)

return transport

@cached_property
def _all_name_types(self) -> set[tuple[str, type]]:
all_name_types = set()
for blueprint in self.blueprints:
for name, type in blueprint.incoming.items():
all_name_types.add((name, type))
for name, type in blueprint.outgoing.items():
all_name_types.add((name, type))
return all_name_types

def _is_name_unique(self, name: str) -> bool:
return sum(1 for n, _ in self._all_name_types if n == name) == 1

def build(self, n: int | None = None) -> Dimos:
dimos = Dimos(n=n)

dimos.start()

# Deploy all modules.
for blueprint in self.blueprints:
dimos.deploy(blueprint.module, *blueprint.args, **blueprint.kwargs)

# Gather all the In/Out connections.
incoming = defaultdict(list)
outgoing = defaultdict(list)
for blueprint in self.blueprints:
for name, type in blueprint.incoming.items():
incoming[(name, type)].append(blueprint.module)
for name, type in blueprint.outgoing.items():
outgoing[(name, type)].append(blueprint.module)

# Connect all In/Out connections by name and type.
for name, type in set(incoming.keys()).union(outgoing.keys()):
transport = self._get_transport_for(name, type)
for module in incoming[(name, type)] + outgoing[(name, type)]:
instance = dimos.get_instance(module)
getattr(instance, name).transport = transport

# Gather all RPC methods.
rpc_methods = {}
for blueprint in self.blueprints:
for method_name in blueprint.module.rpcs.keys():
method = getattr(dimos.get_instance(blueprint.module), method_name)
rpc_methods[f"{blueprint.module.__name__}_{method_name}"] = method

# Fulfil method requests (so modules can call each other).
for blueprint in self.blueprints:
for method_name, method in blueprint.module.rpcs.items():
if not method_name.startswith("set_"):
continue
linked_name = method_name.removeprefix("set_")
if linked_name not in rpc_methods:
continue
instance = dimos.get_instance(blueprint.module)
getattr(instance, method_name)(rpc_methods[linked_name])

dimos.start_all_modules()

return dimos


def make_module_blueprint(
module: type[Module], args: tuple[Any], kwargs: dict[str, Any]
) -> ModuleBlueprint:
incoming: dict[str, type] = {}
outgoing: dict[str, type] = {}

all_annotations = {}
for base_class in reversed(module.__mro__):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)

for name, annotation in all_annotations.items():
origin = get_origin(annotation)
if origin not in (In, Out):
continue
dict_ = incoming if origin == In else outgoing
dict_[name] = get_args(annotation)[0]

return ModuleBlueprint(
module=module, incoming=incoming, outgoing=outgoing, args=args, kwargs=kwargs
)


def create_module_blueprint(module: type[Module], *args: Any, **kwargs: Any) -> ModuleBlueprintSet:
blueprint = make_module_blueprint(module, args, kwargs)
return ModuleBlueprintSet(blueprints=(blueprint,))


def autoconnect(*blueprints: ModuleBlueprintSet) -> ModuleBlueprintSet:
all_blueprints = tuple(bp for bs in blueprints for bp in bs.blueprints)
all_transports = dict(sum([list(x.transports.items()) for x in blueprints], []))
return ModuleBlueprintSet(
blueprints=all_blueprints, transports=MappingProxyType(all_transports)
)
40 changes: 40 additions & 0 deletions dimos/core/test_blueprints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dimos.core.blueprints import ModuleBlueprint, make_module_blueprint
from dimos.core.module import Module
from dimos.core.stream import In, Out


class Scratch:
pass


class Petting:
pass


class CatModule(Module):
pet_cat: In[Petting]
scratches: Out[Scratch]


def test_get_connection_set():
assert make_module_blueprint(CatModule, args=(), kwargs={}) == ModuleBlueprint(
module=CatModule,
incoming={"pet_cat": Petting},
outgoing={"scratches": Scratch},
args=(),
kwargs={},
)
15 changes: 15 additions & 0 deletions dimos/navigation/bt_navigator/navigator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
Navigator module for coordinating global and local planning.
"""

from functools import partial
import threading
import time
from enum import Enum
from typing import Callable, Optional

from dimos.core import Module, In, Out, rpc
from dimos.core.blueprints import create_module_blueprint
from dimos.msgs.geometry_msgs import PoseStamped
from dimos.msgs.nav_msgs import OccupancyGrid
from dimos_lcm.std_msgs import String
Expand Down Expand Up @@ -121,6 +123,16 @@ def __init__(

logger.info("Navigator initialized with stuck detection")

@rpc
def set_HolonomicLocalPlanner_reset(self, callable) -> None:
self.reset_local_planner = callable
self.reset_local_planner.rpc = self.rpc

@rpc
def set_HolonomicLocalPlanner_atgl(self, callable) -> None:
self.check_goal_reached = callable
self.check_goal_reached.rpc = self.rpc

@rpc
def start(self):
super().start()
Expand Down Expand Up @@ -342,3 +354,6 @@ def stop_navigation(self) -> None:
self.recovery_server.reset() # Reset recovery server when stopping

logger.info("Navigator stopped")


behavior_tree_navigator = partial(create_module_blueprint, BehaviorTreeNavigator)
2 changes: 1 addition & 1 deletion dimos/navigation/frontier_exploration/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer
from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer, wavefront_frontier_explorer
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
for autonomous navigation using the dimos Costmap and Vector types.
"""

from functools import partial
import threading
from collections import deque
from dataclasses import dataclass
Expand All @@ -28,6 +29,7 @@
import numpy as np

from dimos.core import Module, In, Out, rpc
from dimos.core.blueprints import create_module_blueprint
from dimos.msgs.geometry_msgs import PoseStamped, Vector3
from dimos.msgs.nav_msgs import OccupancyGrid, CostValues
from dimos.utils.logging_config import setup_logger
Expand Down Expand Up @@ -810,3 +812,6 @@ def _exploration_loop(self):
f"No frontier found (attempt {consecutive_failures}/{max_consecutive_failures}). Retrying in 2 seconds..."
)
threading.Event().wait(2.0)


wavefront_frontier_explorer = partial(create_module_blueprint, WavefrontFrontierExplorer)
4 changes: 3 additions & 1 deletion dimos/navigation/global_planner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from dimos.navigation.global_planner.planner import AstarPlanner
from dimos.navigation.global_planner.planner import AstarPlanner, astar_planner
from dimos.navigation.global_planner.algo import astar

__all__ = ["AstarPlanner", "astar_planner", "astar"]
5 changes: 5 additions & 0 deletions dimos/navigation/global_planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

from dimos.core import In, Module, Out, rpc
from dimos.core.blueprints import create_module_blueprint
from dimos.msgs.geometry_msgs import Pose, PoseStamped
from dimos.msgs.nav_msgs import OccupancyGrid, Path
from dimos.navigation.global_planner.algo import astar
Expand Down Expand Up @@ -216,3 +218,6 @@ def plan(self, goal: Pose) -> Optional[Path]:

logger.warning("No path found to the goal.")
return None


astar_planner = partial(create_module_blueprint, AstarPlanner)
Loading