diff --git a/dimos/core/README_BLUEPRINTS.md b/dimos/core/README_BLUEPRINTS.md index d54000cc6a..26143bd456 100644 --- a/dimos/core/README_BLUEPRINTS.md +++ b/dimos/core/README_BLUEPRINTS.md @@ -93,12 +93,12 @@ If you don't like the name you can always override it like in the next section. By default `LCMTransport` is used if the object supports `lcm_encode`. If it doesn't `pLCMTransport` is used (meaning "pickled LCM"). -You can override transports with the `with_transports` method. It returns a new blueprint in which the override is set. +You can override transports with the `transports` method. It returns a new blueprint in which the override is set. ```python blueprint = autoconnect(...) expanded_blueprint = autoconnect(blueprint, ...) -blueprint = blueprint.with_transports({ +blueprint = blueprint.transports({ ("image", Image): pSHMTransport( "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ), @@ -108,6 +108,47 @@ blueprint = blueprint.with_transports({ Note: `expanded_blueprint` does not get the transport overrides because it's created from the initial value of `blueprint`, not the second. +## Remapping connections + +Sometimes you need to rename a connection to match what other modules expect. You can use `remappings` to rename module connections: + +```python +class ConnectionModule(Module): + color_image: Out[Image] = None # Outputs on 'color_image' + +class ProcessingModule(Module): + rgb_image: In[Image] = None # Expects input on 'rgb_image' + +# Without remapping, these wouldn't connect automatically +# With remapping, color_image is renamed to rgb_image +blueprint = ( + autoconnect( + ConnectionModule.blueprint(), + ProcessingModule.blueprint(), + ) + .remappings([ + (ConnectionModule, 'color_image', 'rgb_image'), + ]) +) +``` + +After remapping: +- The `color_image` output from `ConnectionModule` is treated as `rgb_image` +- It automatically connects to any module with an `rgb_image` input of type `Image` +- The topic name becomes `/rgb_image` instead of `/color_image` + +If you want to override the topic, you still have to do it manually: + +```python +blueprint +.remappings([ + (ConnectionModule, 'color_image', 'rgb_image'), +]) +.transports({ + ("rgb_image", Image): LCMTransport("/custom/rgb/image", Image), +}) +``` + ## Overriding global configuration. Each module can optionally take a `global_config` option in `__init__`. E.g.: @@ -122,7 +163,7 @@ class ModuleA(Module): The config is normally taken from .env or from environment variables. But you can specifically override the values for a specific blueprint: ```python -blueprint = blueprint.with_global_config(n_dask_workers=8) +blueprint = blueprint.global_config(n_dask_workers=8) ``` ## Calling the methods of other modules @@ -213,7 +254,7 @@ This returns a `ModuleCoordinator` instance that manages all deployed modules. You can block the thread until it exits with: ```python -module_coordinator.wait_until_shutdown() +module_coordinator.loop() ``` -This will wait for Ctrl+C and then automatically stop all modules and clean up resources. \ No newline at end of file +This will wait for Ctrl+C and then automatically stop all modules and clean up resources. diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 53f20a0bfb..8743a46815 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -46,25 +46,44 @@ class ModuleBlueprint: class ModuleBlueprintSet: blueprints: tuple[ModuleBlueprint, ...] # TODO: Replace Any - transports: Mapping[tuple[str, type], Any] = field(default_factory=lambda: MappingProxyType({})) + transport_map: Mapping[tuple[str, type], Any] = field( + default_factory=lambda: MappingProxyType({}) + ) global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) + remapping_map: Mapping[tuple[type[Module], str], str] = field( + default_factory=lambda: MappingProxyType({}) + ) - def with_transports(self, transports: dict[tuple[str, type], Any]) -> "ModuleBlueprintSet": + def transports(self, transports: dict[tuple[str, type], Any]) -> "ModuleBlueprintSet": return ModuleBlueprintSet( blueprints=self.blueprints, - transports=MappingProxyType({**self.transports, **transports}), + transport_map=MappingProxyType({**self.transport_map, **transports}), global_config_overrides=self.global_config_overrides, + remapping_map=self.remapping_map, ) - def with_global_config(self, **kwargs: Any) -> "ModuleBlueprintSet": + def global_config(self, **kwargs: Any) -> "ModuleBlueprintSet": return ModuleBlueprintSet( blueprints=self.blueprints, - transports=self.transports, + transport_map=self.transport_map, global_config_overrides=MappingProxyType({**self.global_config_overrides, **kwargs}), + remapping_map=self.remapping_map, + ) + + def remappings(self, remappings: list[tuple[type[Module], str, str]]) -> "ModuleBlueprintSet": + remappings_dict = dict(self.remapping_map) + for module, old, new in remappings: + remappings_dict[(module, old)] = new + + return ModuleBlueprintSet( + blueprints=self.blueprints, + transport_map=self.transport_map, + global_config_overrides=self.global_config_overrides, + remapping_map=MappingProxyType(remappings_dict), ) def _get_transport_for(self, name: str, type: type) -> Any: - transport = self.transports.get((name, type), None) + transport = self.transport_map.get((name, type), None) if transport: return transport @@ -76,16 +95,21 @@ def _get_transport_for(self, name: str, type: type) -> Any: @cached_property def _all_name_types(self) -> set[tuple[str, type]]: - return { - (conn.name, conn.type) - for blueprint in self.blueprints - for conn in blueprint.connections - } + # Apply remappings to get the actual names that will be used + result = set() + for blueprint in self.blueprints: + for conn in blueprint.connections: + # Check if this connection should be remapped + remapped_name = self.remapping_map.get((blueprint.module, conn.name), conn.name) + result.add((remapped_name, conn.type)) + return result def _is_name_unique(self, name: str) -> bool: return sum(1 for n, _ in self._all_name_types if n == name) == 1 - def build(self, global_config: GlobalConfig) -> ModuleCoordinator: + def build(self, global_config: GlobalConfig | None = None) -> ModuleCoordinator: + if global_config is None: + global_config = GlobalConfig() global_config = global_config.model_copy(update=self.global_config_overrides) module_coordinator = ModuleCoordinator(global_config=global_config) @@ -100,18 +124,27 @@ def build(self, global_config: GlobalConfig) -> ModuleCoordinator: kwargs["global_config"] = global_config module_coordinator.deploy(blueprint.module, *blueprint.args, **kwargs) - # Gather all the In/Out connections. + # Gather all the In/Out connections with remapping applied. connections = defaultdict(list) + # Track original name -> remapped name for each module + module_conn_mapping = defaultdict(dict) + for blueprint in self.blueprints: for conn in blueprint.connections: - connections[conn.name, conn.type].append(blueprint.module) - - # Connect all In/Out connections by name and type. - for name, type in connections.keys(): - transport = self._get_transport_for(name, type) - for module in connections[(name, type)]: + # Check if this connection should be remapped + remapped_name = self.remapping_map.get((blueprint.module, conn.name), conn.name) + # Store the mapping for later use + module_conn_mapping[blueprint.module][conn.name] = remapped_name + # Group by remapped name and type + connections[remapped_name, conn.type].append((blueprint.module, conn.name)) + + # Connect all In/Out connections by remapped name and type. + for remapped_name, type in connections.keys(): + transport = self._get_transport_for(remapped_name, type) + for module, original_name in connections[(remapped_name, type)]: instance = module_coordinator.get_instance(module) - getattr(instance, name).transport = transport + # Use the remote method to set transport on Dask actors + instance.set_transport(original_name, transport) # Gather all RPC methods. rpc_methods = {} @@ -164,15 +197,17 @@ def create_module_blueprint(module: type[Module], *args: Any, **kwargs: Any) -> def autoconnect(*blueprints: ModuleBlueprintSet) -> ModuleBlueprintSet: all_blueprints = tuple(_eliminate_duplicates([bp for bs in blueprints for bp in bs.blueprints])) - all_transports = dict(sum([list(x.transports.items()) for x in blueprints], [])) + all_transports = dict(sum([list(x.transport_map.items()) for x in blueprints], [])) all_config_overrides = dict( sum([list(x.global_config_overrides.items()) for x in blueprints], []) ) + all_remappings = dict(sum([list(x.remapping_map.items()) for x in blueprints], [])) return ModuleBlueprintSet( blueprints=all_blueprints, - transports=MappingProxyType(all_transports), + transport_map=MappingProxyType(all_transports), global_config_overrides=MappingProxyType(all_config_overrides), + remapping_map=MappingProxyType(all_remappings), ) diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index e25184c351..96e56f914f 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -23,7 +23,6 @@ class GlobalConfig(BaseSettings): n_dask_workers: int = 2 model_config = SettingsConfigDict( - env_prefix="DIMOS_", env_file=".env", env_file_encoding="utf-8", extra="ignore", diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 6eb916fda3..43c522305f 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -63,7 +63,7 @@ def start_all_modules(self) -> None: def get_instance(self, module: Type[T]) -> T | None: return self._deployed_modules.get(module) - def wait_until_shutdown(self) -> None: + def loop(self) -> None: try: while True: time.sleep(0.1) diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index edce54f2e1..4025d183ca 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -131,18 +131,18 @@ def test_autoconnect(): ) -def test_with_transports(): +def test_transports(): custom_transport = LCMTransport("/custom_topic", Data1) - blueprint_set = autoconnect(module_a(), module_b()).with_transports( + blueprint_set = autoconnect(module_a(), module_b()).transports( {("data1", Data1): custom_transport} ) - assert ("data1", Data1) in blueprint_set.transports - assert blueprint_set.transports[("data1", Data1)] == custom_transport + assert ("data1", Data1) in blueprint_set.transport_map + assert blueprint_set.transport_map[("data1", Data1)] == custom_transport -def test_with_global_config(): - blueprint_set = autoconnect(module_a(), module_b()).with_global_config(option1=True, option2=42) +def test_global_config(): + blueprint_set = autoconnect(module_a(), module_b()).global_config(option1=True, option2=42) assert "option1" in blueprint_set.global_config_overrides assert blueprint_set.global_config_overrides["option1"] is True @@ -183,3 +183,60 @@ def test_build_happy_path(): finally: coordinator.stop() + + +def test_remapping(): + """Test that remapping connections works correctly.""" + pubsub.lcm.autoconf() + + # Define test modules with connections that will be remapped + class SourceModule(Module): + color_image: Out[Data1] = None # Will be remapped to 'remapped_data' + + class TargetModule(Module): + remapped_data: In[Data1] = None # Receives the remapped connection + + # Create blueprint with remapping + blueprint_set = autoconnect( + SourceModule.blueprint(), + TargetModule.blueprint(), + ).remappings( + [ + (SourceModule, "color_image", "remapped_data"), + ] + ) + + # Verify remappings are stored correctly + assert (SourceModule, "color_image") in blueprint_set.remapping_map + assert blueprint_set.remapping_map[(SourceModule, "color_image")] == "remapped_data" + + # Verify that remapped names are used in name resolution + assert ("remapped_data", Data1) in blueprint_set._all_name_types + # The original name shouldn't be in the name types since it's remapped + assert ("color_image", Data1) not in blueprint_set._all_name_types + + # Build and verify connections work + coordinator = blueprint_set.build(GlobalConfig()) + + try: + source_instance = coordinator.get_instance(SourceModule) + target_instance = coordinator.get_instance(TargetModule) + + assert source_instance is not None + assert target_instance is not None + + # Both should have transports set + assert source_instance.color_image.transport is not None + assert target_instance.remapped_data.transport is not None + + # They should be using the same transport (connected) + assert ( + source_instance.color_image.transport.topic + == target_instance.remapped_data.transport.topic + ) + + # The topic should be /remapped_data since that's the remapped name + assert target_instance.remapped_data.transport.topic == "/remapped_data" + + finally: + coordinator.stop() diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 2eef48855f..096af58b94 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -22,6 +22,8 @@ "unitree-go2-shm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_shm", "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", "demo-osm": "dimos.mapping.osm.demo_osm:demo_osm", + "demo-remapping": "dimos.robot.unitree_webrtc.demo_remapping:remapping", + "demo-remapping-transport": "dimos.robot.unitree_webrtc.demo_remapping:remapping_and_transport", } diff --git a/dimos/robot/cli/README.md b/dimos/robot/cli/README.md index 164fc8538c..da1d7443da 100644 --- a/dimos/robot/cli/README.md +++ b/dimos/robot/cli/README.md @@ -51,10 +51,10 @@ class GlobalConfig(BaseSettings): Configuration values can be set from multiple places in order of precedence (later entries override earlier ones): - Default value defined on GlobalConfig. (`use_simulation = False`) -- Value defined in `.env` (`DIMOS_USE_SIMULATION=true`) -- Value in the environment variable (`DIMOS_USE_SIMULATION=true`) +- Value defined in `.env` (`USE_SIMULATION=true`) +- Value in the environment variable (`USE_SIMULATION=true`) - Value coming from the CLI (`--use-simulation` or `--no-use-simulation`) -- Value defined on the blueprint (`blueprint.with_global_config(use_simulation=True)`) +- Value defined on the blueprint (`blueprint.global_config(use_simulation=True)`) For environment variables/`.env` values, you have to prefix the name with `DIMOS_`. diff --git a/dimos/robot/cli/dimos_robot.py b/dimos/robot/cli/dimos_robot.py index 5b589b3d69..8a40f76eb6 100644 --- a/dimos/robot/cli/dimos_robot.py +++ b/dimos/robot/cli/dimos_robot.py @@ -113,7 +113,7 @@ def run( blueprint = autoconnect(blueprint, *loaded_modules) dimos = blueprint.build(global_config=config) - dimos.wait_until_shutdown() + dimos.loop() @main.command() diff --git a/dimos/robot/unitree_webrtc/demo_remapping.py b/dimos/robot/unitree_webrtc/demo_remapping.py new file mode 100644 index 0000000000..a0b594f95a --- /dev/null +++ b/dimos/robot/unitree_webrtc/demo_remapping.py @@ -0,0 +1,30 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.transport import LCMTransport +from dimos.msgs.sensor_msgs import Image +from dimos.robot.unitree_webrtc.unitree_go2 import ConnectionModule +from dimos.robot.unitree_webrtc.unitree_go2_blueprints import standard + +remapping = standard.remappings( + [ + (ConnectionModule, "color_image", "rgb_image"), + ] +) + +remapping_and_transport = remapping.transports( + { + ("rgb_image", Image): LCMTransport("/go2/color_image", Image), + } +) diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index af13dc20bc..47e4ce6c8c 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -54,8 +54,8 @@ websocket_vis(), foxglove_bridge(), ) - .with_global_config(n_dask_workers=4) - .with_transports( + .global_config(n_dask_workers=4) + .transports( # These are kept the same so that we don't have to change foxglove configs. # Although we probably should. { @@ -74,8 +74,8 @@ depth_module(), utilization(), ) - .with_global_config(n_dask_workers=8) - .with_transports( + .global_config(n_dask_workers=8) + .transports( { ("depth_image", Image): LCMTransport("/go2/depth_image", Image), } @@ -83,7 +83,7 @@ ) standard_with_shm = autoconnect( - standard.with_transports( + standard.transports( { ("color_image", Image): pSHMTransport( "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE