diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 00491b4379..212c7ac60a 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -20,9 +20,12 @@ import mujoco import numpy as np -import onnxruntime as rt # type: ignore[import-untyped] +import onnxruntime as ort # type: ignore[import-untyped] from dimos.simulation.mujoco.input_controller import InputController +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() class OnnxController(ABC): @@ -37,7 +40,8 @@ def __init__( drift_compensation: list[float] | None = None, ) -> None: self._output_names = ["continuous_actions"] - self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"]) + self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers()) + logger.info(f"Loaded policy: {policy_path} with providers: {self._policy.get_providers()}") self._action_scale = action_scale self._default_angles = default_angles