From f930e3ee73d75e8ca5cd3bb2190afa17a4afbd82 Mon Sep 17 00:00:00 2001 From: Miguel Villa Floran Date: Mon, 12 Jan 2026 15:50:56 -0800 Subject: [PATCH 1/3] refactor(policy): update inference session initialization --- dimos/simulation/mujoco/policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 00491b4379..52a30d9f30 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -20,7 +20,7 @@ import mujoco import numpy as np -import onnxruntime as rt # type: ignore[import-untyped] +import onnxruntime as ort # type: ignore[import-untyped] from dimos.simulation.mujoco.input_controller import InputController @@ -37,7 +37,7 @@ def __init__( drift_compensation: list[float] | None = None, ) -> None: self._output_names = ["continuous_actions"] - self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"]) + self._policy = ort.InferenceSession(policy_path, providers=[ort.get_available_providers()]) self._action_scale = action_scale self._default_angles = default_angles From 9e0f7516c6c364901e47267eb2eea931c3fcae84 Mon Sep 17 00:00:00 2001 From: Miguel Villa Floran Date: Mon, 12 Jan 2026 16:01:40 -0800 Subject: [PATCH 2/3] refactor(policy): simplify inference session provider initialization --- dimos/simulation/mujoco/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 52a30d9f30..5137e16113 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -37,7 +37,7 @@ def __init__( drift_compensation: list[float] | None = None, ) -> None: self._output_names = ["continuous_actions"] - self._policy = ort.InferenceSession(policy_path, providers=[ort.get_available_providers()]) + self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers()) self._action_scale = action_scale self._default_angles = default_angles From 9be157da828e436cc54fae29fe9ed96a75875083 Mon Sep 17 00:00:00 2001 From: Miguel Villa Floran Date: Mon, 12 Jan 2026 19:41:08 -0800 Subject: [PATCH 3/3] Log the policy directory and provider --- dimos/simulation/mujoco/policy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 5137e16113..212c7ac60a 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -23,6 +23,9 @@ import onnxruntime as ort # type: ignore[import-untyped] from dimos.simulation.mujoco.input_controller import InputController +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() class OnnxController(ABC): @@ -38,6 +41,7 @@ def __init__( ) -> None: self._output_names = ["continuous_actions"] self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers()) + logger.info(f"Loaded policy: {policy_path} with providers: {self._policy.get_providers()}") self._action_scale = action_scale self._default_angles = default_angles