8 changes: 6 additions & 2 deletions dimos/simulation/mujoco/policy.py
@@ -20,9 +20,12 @@

import mujoco
import numpy as np
-import onnxruntime as rt  # type: ignore[import-untyped]
+import onnxruntime as ort  # type: ignore[import-untyped]

from dimos.simulation.mujoco.input_controller import InputController
+from dimos.utils.logging_config import setup_logger
+
+logger = setup_logger()


class OnnxController(ABC):
@@ -37,7 +40,8 @@ def __init__(
drift_compensation: list[float] | None = None,
) -> None:
self._output_names = ["continuous_actions"]
-self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"])
+self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers())
Contributor:
Consider adding logging to show which execution provider is actually being used after initialization. This would help with debugging and verifying that GPU acceleration is working when available.

Similar to the pattern used in dimos/agents_deprecated/memory/image_embedding.py:89-91:

self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers())
actual_providers = self._policy.get_providers()
logger.info(f"Loaded policy with providers: {actual_providers}")

This is especially useful since get_available_providers() returns all available providers, but the InferenceSession may only successfully initialize with a subset of them.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Contributor:
The change from hardcoded ["CPUExecutionProvider"] to dynamic ort.get_available_providers() lacks error handling and observability. While the implementation is technically correct (contrary to the previous review thread), there are important considerations:

Issues:

  1. No logging: Users won't know which execution provider is actually being used, making debugging and performance analysis difficult
  2. No error handling: If provider initialization fails (e.g., CUDA provider available but GPU out of memory), there's no graceful fallback or informative error message
  3. Silent performance changes: Different providers have different performance characteristics and numerical precision - users should be informed which one is selected

Recommendation:
Add logging after initialization to show which provider was selected, similar to the pattern used in dimos/agents_deprecated/memory/image_embedding.py:89-91:

self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers())
actual_providers = self._policy.get_providers()
# Log: f"Initialized ONNX policy with providers: {actual_providers}"

This would help users understand which execution provider is being used and debug any performance or behavior differences.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
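The fallback-plus-logging behavior the comment asks for can be sketched as below. This is a hypothetical helper, not code from the PR: the name `create_session_with_fallback` and the injected `session_factory` parameter are illustrative (in real use the factory would be `ort.InferenceSession`); injecting the factory just lets the fallback logic be shown without a GPU or onnxruntime installed.

```python
import logging

logger = logging.getLogger(__name__)


def create_session_with_fallback(policy_path, preferred_providers, session_factory):
    """Try the preferred providers; fall back to CPU-only if init fails.

    `session_factory` stands in for ort.InferenceSession here so the
    fallback logic can be demonstrated without onnxruntime.
    """
    try:
        session = session_factory(policy_path, providers=preferred_providers)
    except Exception as exc:  # e.g. CUDA provider listed but GPU out of memory
        logger.warning("Provider init failed (%s); retrying with CPU only", exc)
        session = session_factory(policy_path, providers=["CPUExecutionProvider"])
    # Report the providers the session actually initialized with -- this may
    # be a subset of those requested, as the comments above point out.
    logger.info("Loaded policy %s with providers: %s", policy_path, session.get_providers())
    return session
```

With onnxruntime available this would be called as `create_session_with_fallback(policy_path, ort.get_available_providers(), ort.InferenceSession)`, giving both the informative log line and a graceful CPU fallback instead of a hard failure.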

+logger.info(f"Loaded policy: {policy_path} with providers: {self._policy.get_providers()}")

self._action_scale = action_scale
self._default_angles = default_angles