From c09ebfe03694100800b9bc4b556a13671a8f5ab3 Mon Sep 17 00:00:00 2001 From: Lotus Li Date: Thu, 14 Aug 2025 14:58:50 -0700 Subject: [PATCH 1/2] Add visualization framework and Teleop visualizers --- CONTRIBUTORS.md | 1 + scripts/tools/record_demos.py | 24 +- scripts/tools/teleop_visualization_manager.py | 337 ++++++++++ .../tools/test/scene_visualization_sample.py | 70 ++ .../isaaclab/isaaclab/devices/device_base.py | 4 +- .../humanoid/fourier/gr1t2_retargeter.py | 35 +- .../isaaclab/envs/manager_based_rl_env.py | 10 +- .../isaaclab/ui/xr_widgets/__init__.py | 3 +- .../ui/xr_widgets/instruction_widget.py | 93 ++- .../ui/xr_widgets/scene_visualization.py | 607 ++++++++++++++++++ 10 files changed, 1143 insertions(+), 41 deletions(-) create mode 100644 scripts/tools/teleop_visualization_manager.py create mode 100644 scripts/tools/test/scene_visualization_sample.py create mode 100644 source/isaaclab/isaaclab/ui/xr_widgets/scene_visualization.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 3c683ebe4f06..562cae8979d8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -85,6 +85,7 @@ Guidelines for modifications: * Kourosh Darvish * Kousheek Chakraborty * Lionel Gulich +* Lotus Li * Louis Le Lay * Lorenz Wellhausen * Lukas Fröhlich diff --git a/scripts/tools/record_demos.py b/scripts/tools/record_demos.py index d9bacd5c2537..dc7bd93ebb7e 100644 --- a/scripts/tools/record_demos.py +++ b/scripts/tools/record_demos.py @@ -108,7 +108,8 @@ import isaaclab_tasks # noqa: F401 from isaaclab_tasks.utils.parse_cfg import parse_env_cfg - +from isaaclab.ui.xr_widgets import XRVisualization +from teleop_visualization_manager import TeleopVisualizationManager class RateLimiter: """Convenience class for enforcing rates in loops.""" @@ -469,16 +470,24 @@ def stop_recording_instance(): label_text = f"Recorded {current_recorded_demo_count} successful demonstrations." print(label_text) + # Check if we've reached the desired number of demos + if args_cli.num_demos > 0 and env.recorder_manager.exported_successful_episode_count >= args_cli.num_demos: + label_text = f"All {current_recorded_demo_count} demonstrations recorded.\nExiting the app." + instruction_display.show_demo(label_text) + print(label_text) + target_time = time.time() + 0.8 + while time.time() < target_time: + if rate_limiter: + rate_limiter.sleep(env) + else: + env.sim.render() + break + # Handle reset if requested if should_reset_recording_instance: success_step_count = handle_reset(env, success_step_count, instruction_display, label_text) should_reset_recording_instance = False - # Check if we've reached the desired number of demos - if args_cli.num_demos > 0 and env.recorder_manager.exported_successful_episode_count >= args_cli.num_demos: - print(f"All {args_cli.num_demos} demonstrations recorded. Exiting the app.") - break - # Check if simulation is stopped if env.sim.is_stopped(): break @@ -512,6 +521,9 @@ def main() -> None: # Set up output directories output_dir, output_file_name = setup_output_directories() + # Assign the teleop visualization manager to the visualization system + XRVisualization.assign_manager(TeleopVisualizationManager) + # Create and configure environment global env_cfg # Make env_cfg available to setup_teleop_device env_cfg, success_term = create_environment_config(output_dir, output_file_name) diff --git a/scripts/tools/teleop_visualization_manager.py b/scripts/tools/teleop_visualization_manager.py new file mode 100644 index 000000000000..e87729f710e8 --- /dev/null +++ b/scripts/tools/teleop_visualization_manager.py @@ -0,0 +1,337 @@ +import isaaclab.sim as sim_utils +from isaaclab.markers import VisualizationMarkers, VisualizationMarkersCfg +from isaaclab.ui.xr_widgets.instruction_widget import hide_instruction +from isaaclab.ui.xr_widgets import VisualizationManager, TriggerType, DataCollector, VisualizationManager +from pxr import Gf, Usd +import numpy as np +from isaaclab.devices.openxr.openxr_device import OpenXRDevice +import torch +from typing import Any, Final +import omni.kit.app +import json +import carb +from omni.kit.viewport.utility.camera_state import ViewportCameraState + +def send_message_to_client(message: dict): + """Send a message to the CloudXR client. + + Args: + message (dict or str): Message to send (will be converted to JSON if it's a dict) + """ + if isinstance(message, dict): + message_str = json.dumps(message) + else: + message_str = message + + omni.kit.app.queue_event("omni.kit.cloudxr.send_message", payload={"message": message_str}) + +class TeleopVisualizationManager(VisualizationManager): + """Specialized visualization manager for teleoperation scenarios. + For sample and debug use. + + Provides teleoperation-specific visualization features including: + - IK error handling and display + - Hand position tracking and range indicators + - Real-time data panels to display data in DataCollector + """ + + def __init__(self, data_collector: DataCollector): + """Initialize the teleop visualization manager and register callbacks. + + Args: + data_collector: DataCollector instance to read data for visualization use. + """ + super().__init__(data_collector) + + # Register the event alias for sending messages to the CloudXR client + carb_event = carb.events.type_from_string("omni.kit.cloudxr.send_message") + omni.kit.app.register_event_alias(carb_event, "omni.kit.cloudxr.send_message") + + # Config whether to visualize the markers. Default to False. + self._enable_visualization = False + + # Register callback to update the enable_visualization + self.register_callback(TriggerType.TRIGGER_ON_EVENT, {"event_name": "enable_teleop_visualization"}, self._handle_enable_visualization) + + # Handle error event + self._error_text_color = 0xFF0000FF + self.ik_error_widget_id = "/ik_solver_failed" + + #self.display_widget("IK Error Detected", self.ik_error_widget_id, VisualizationManager.message_widget_preset() | {"text_color": self._error_text_color, "display_duration": None}) + + self.register_callback(TriggerType.TRIGGER_ON_EVENT, {"event_name": "ik_error"}, self._handle_ik_error) + + # Handle torque skeleton + self._num_open_xr_hand_joints = 52 + marker_cfg = VisualizationMarkersCfg( + prim_path="/Visuals/skeleton_joints", + markers={ + "grey": sim_utils.SphereCfg( + radius=0.005, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.3, 0.5, 0.3)), + ), + "green": sim_utils.SphereCfg( + radius=0.005, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 1.0, 0.0)), + ), + "yellow": sim_utils.SphereCfg( + radius=0.005, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 1.0, 0.0)), + ), + "red": sim_utils.SphereCfg( + radius=0.005, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 0.0, 0.0)), + ), + }, + ) + self._markers_joints = VisualizationMarkers(marker_cfg) + self._markers_joints.set_visibility(False) + + marker_cfg = VisualizationMarkersCfg( + prim_path="/Visuals/skeleton_lines", + markers={ + "grey": sim_utils.CylinderCfg( + radius=0.001, + height=1, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.3, 0.5, 0.3)), + ), + "green": sim_utils.CylinderCfg( + radius=0.001, + height=1, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 1.0, 0.0)), + ), + "yellow": sim_utils.CylinderCfg( + radius=0.001, + height=1, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 1.0, 0.0)), + ), + "red": sim_utils.CylinderCfg( + radius=0.001, + height=1, + visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 0.0, 0.0)), + ), + }, + ) + self._markers_lines = VisualizationMarkers(marker_cfg) + self._markers_lines.set_visibility(False) + # Expect to update the skeleton every frame + self.register_callback(TriggerType.TRIGGER_ON_UPDATE, {}, self._update_torque_skeleton) + + # Todo: enable this after StreamSDK supports sending large data through the channel. + # For now, client app crashes when receiving large data (~1KB per frame) + # self.register_callback(TriggerType.TRIGGER_ON_UPDATE, {}, self._send_visualization_data_to_client) + + def _handle_enable_visualization(self, mgr: VisualizationManager, data_collector: DataCollector, enabled: Any) -> None: + """Update the enable visualization. + + Args: + data_collector: DataCollector instance containing current data + """ + if enabled is not None and enabled != self._enable_visualization: + self._enable_visualization = enabled + self._markers_joints.set_visibility(enabled) + self._markers_lines.set_visibility(enabled) + + def _handle_ik_error(self, mgr: VisualizationManager, data_collector: DataCollector, params: Any = None) -> None: + """Handle IK error events by displaying an error message widget. + + Args: + data_collector: DataCollector instance (unused in this handler) + """ + # Todo: move display_widget to instruction_widget.py + if not hasattr(mgr, "_ik_error_widget_timer"): + self.display_widget("IK Error Detected", mgr.ik_error_widget_id, VisualizationManager.message_widget_preset() | {"text_color": self._error_text_color, "display_duration": None}) + mgr._ik_error_widget_timer = mgr.register_callback(TriggerType.TRIGGER_ON_PERIOD, {"period": 3.0, "initial_countdown": 3.0}, self._hide_ik_error_widget) + if mgr._ik_error_widget_timer is None: + mgr.cancel_rule(TriggerType.TRIGGER_ON_PERIOD, mgr._ik_error_widget_timer) + mgr.cancel_rule(TriggerType.TRIGGER_ON_EVENT, "ik_solver_failed") + raise RuntimeWarning("Failed to register IK error widget timer") + else: + mgr._ik_error_widget_timer.countdown = 3.0 + + def _hide_ik_error_widget(self, mgr: VisualizationManager, data_collector: DataCollector) -> None: + """Hide the IK error widget. + + Args: + data_collector: DataCollector instance (unused in this handler) + """ + + hide_instruction(mgr.ik_error_widget_id) + mgr.cancel_rule(TriggerType.TRIGGER_ON_PERIOD, mgr._ik_error_widget_timer) + delattr(mgr, "_ik_error_widget_timer") + + + def _update_torque_skeleton(self, mgr: VisualizationManager, data_collector: DataCollector) -> None: + """Update the torque skeleton. + + Args: + data_collector: DataCollector instance containing current data + """ + if not mgr._enable_visualization: + return + + data = data_collector.get_data("device_raw_data") + sim_device = data_collector.get_data("sim_device") + if data is None or sim_device is None: + return + + left_hand_poses = data.get(OpenXRDevice.TrackingTarget.HAND_LEFT) + right_hand_poses = data.get(OpenXRDevice.TrackingTarget.HAND_RIGHT) + + joints_position = np.zeros((self._num_open_xr_hand_joints, 3)) + + # Extract joint positions from both hands + left_joints = np.array([pose[:3] for pose in left_hand_poses.values()]) + right_joints = np.array([pose[:3] for pose in right_hand_poses.values()]) + + # Fill the first part with left hand joints + num_left = len(left_joints) + joints_position[:num_left] = left_joints + + # Fill the second part with right hand joints + joints_position[num_left:num_left + len(right_joints)] = right_joints + + viewport_api = omni.kit.viewport.utility.get_active_viewport() + # camera_state = ViewportCameraState(viewport_api.get_active_camera(), viewport_api) + # camera_position = np.array(camera_state.position) + + camera_state = ViewportCameraState(viewport_api.get_active_camera(), viewport_api) + world_transform = camera_state.usd_camera.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) + camera_position = np.array(world_transform.ExtractTranslation()) + + # Move all joints closer to the camera for better visualization + direction_to_camera = camera_position - joints_position + distance_to_camera = np.linalg.norm(direction_to_camera, axis=1, keepdims=True) + joints_position += direction_to_camera / (distance_to_camera + 1e-8) * 0.03 + + # Calculate midpoints between consecutive joints for line visualization + joints_midpoints = (joints_position[:-1] + joints_position[1:]) / 2 + + # Calculate direction vectors + directions = joints_position[1:] - joints_position[:-1] + + # Remove unnecessary joints + indices_to_remove = [0, 1, 5, 10, 15, 20, 25, 26, 27, 31, 36, 41, 46] + joints_midpoints = np.delete(joints_midpoints, indices_to_remove, axis=0) + directions = np.delete(directions, indices_to_remove, axis=0) + + # Calculate lengths + lengths = np.linalg.norm(directions, axis=1) + + # Normalize direction vectors + normalized_directions = directions / (lengths[:, np.newaxis] + 1e-8) # Add small epsilon to avoid division by zero + + # Calculate orientations (quaternions) to align cylinders with direction vectors + # Cylinder default axis is Z-axis [0, 0, 1] + default_axis = np.array([0, 0, 1]) + orientations = [] + + for direction in normalized_directions: + # Calculate rotation quaternion from default axis to target direction + dot = np.dot(default_axis, direction) + if dot > 0.9999: # Vectors are already aligned + quat = np.array([1, 0, 0, 0]) # Identity quaternion (w, x, y, z) + elif dot < -0.9999: # Vectors are opposite + # Find perpendicular axis for 180-degree rotation + perp = np.array([1, 0, 0]) if abs(default_axis[0]) < 0.9 else np.array([0, 1, 0]) + axis = np.cross(default_axis, perp) + if np.linalg.norm(axis) > 0: + axis = axis / np.linalg.norm(axis) + quat = np.array([0, axis[0], axis[1], axis[2]]) # 180-degree rotation + else: + # General case: calculate rotation axis and angle + axis = np.cross(default_axis, direction) + if np.linalg.norm(axis) > 0: + axis = axis / np.linalg.norm(axis) + angle = np.arccos(np.clip(dot, -1, 1)) + half_angle = angle / 2 + w = np.cos(half_angle) + xyz = axis * np.sin(half_angle) + quat = np.array([w, xyz[0], xyz[1], xyz[2]]) # (w, x, y, z) + + orientations.append(quat) + + orientations = np.array(orientations) + + # Set scales: keep radius at 1, set height to the distance between points + scales = np.column_stack([np.ones(len(lengths)), np.ones(len(lengths)), lengths]) + + colors_lines = np.zeros(len(joints_midpoints)) + colors_joints = np.zeros(len(joints_position)) + + joints_torque : Final = data_collector.get_data("joints_torque") + joints_torque_limit : Final = data_collector.get_data("joints_torque_limit") + joints_name : Final = data_collector.get_data("joints_name") + hand_torque_mapping : Final = data_collector.get_data("hand_torque_mapping") + + # enable_torque_color needs to be manually set by calling XRVisualization.set_attrs({"enable_torque_color": True}) + if getattr(mgr, "enable_torque_color") and joints_torque is not None and joints_torque_limit is not None and joints_name is not None and hand_torque_mapping is not None and len(hand_torque_mapping) == len(colors_lines) - 10: + # Insert empty strings at positions that are not fingers + hand_torque_mapping_copy = hand_torque_mapping.copy() + hand_torque_mapping_copy.append("") + torque_mapping_lines_index = np.array([-1, 0, 1, + -1, 2, 3, 4, + -1, 5, 6, 7, + -1, 8, 9, 10, + -1, 11, 12, 13, + -1, 14, 15, + -1, 16, 17, 18, + -1, 19, 20, 21, + -1, 22, 23, 24, + -1, 25, 26, 27]) + torque_mapping_joints_index = np.array([-1, -1, -1, 0, 1, 1, + -1, 2, 3, 4, 4, + -1, 5, 6, 7, 7, + -1, 8, 9, 10, 10, + -1, 11, 12, 13, 13, + -1, -1, -1, 14, 15, 15, + -1, 16, 17, 18, 18, + -1, 19, 20, 21, 21, + -1, 22, 23, 24, 24, + -1, 25, 26, 27, 27]) + torque_mapping_lines = np.array(hand_torque_mapping_copy)[torque_mapping_lines_index] + torque_mapping_joints = np.array(hand_torque_mapping_copy)[torque_mapping_joints_index] + + # Set colors: 0: grey, 1: green, 2: yellow, 3: red + for i, key in enumerate(torque_mapping_lines): + if key in joints_name: + ratio = joints_torque[joints_name.index(key)] / joints_torque_limit[joints_name.index(key)] + colors_lines[i] = 0 if ratio < 0.05 else 1 if ratio < 0.5 else 2 if ratio < 0.8 else 3 + for i, key in enumerate(torque_mapping_joints): + if key in joints_name: + ratio = joints_torque[joints_name.index(key)] / joints_torque_limit[joints_name.index(key)] + colors_joints[i] = 0 if ratio < 0.05 else 1 if ratio < 0.5 else 2 if ratio < 0.8 else 3 + + self._markers_joints.visualize( + translations=torch.tensor(joints_position, device=sim_device), marker_indices=torch.tensor(colors_joints, device=sim_device) + ) + self._markers_lines.visualize( + translations=torch.tensor(joints_midpoints, device=sim_device), + orientations=torch.tensor(orientations, device=sim_device), + scales=torch.tensor(scales, device=sim_device), + marker_indices=torch.tensor(colors_lines, device=sim_device) + ) + + def _send_visualization_data_to_client(self, mgr: VisualizationManager, data_collector: DataCollector) -> None: + """Send the data to the CloudXR client. + + Args: + data_collector: DataCollector instance containing current data + """ + dic = {"Type": "visualization_data"} + + ellipsoid = data_collector.get_data("manipulability_ellipsoid") + if ellipsoid is not None: + dic["ellipsoid"] = ellipsoid + distance_to_limit = data_collector.get_data("joints_distance_percentage_to_limit") + if distance_to_limit is not None: + dic["distance_to_limit"] = distance_to_limit + torque_to_limit = data_collector.get_data("joints_torque_percentage_of_limit") + if torque_to_limit is not None: + dic["torque_to_limit"] = torque_to_limit + joints_name = data_collector.get_data("joints_name") + if joints_name is not None: + dic["joints_name"] = joints_name + + if len(dic) > 1: + send_message_to_client(dic) diff --git a/scripts/tools/test/scene_visualization_sample.py b/scripts/tools/test/scene_visualization_sample.py new file mode 100644 index 000000000000..435b75fcd98a --- /dev/null +++ b/scripts/tools/test/scene_visualization_sample.py @@ -0,0 +1,70 @@ +from isaaclab.ui.xr_widgets import XRVisualization, TriggerType, VisualizationManager, DataCollector,update_instruction +from pxr import Gf + +def _sample_handle_ik_error(self, mgr: VisualizationManager, data_collector: DataCollector) -> None: + """Handle IK error events by displaying an error message widget. + + Args: + data_collector: DataCollector instance (unused in this handler) + """ + + self.display_widget("IK Error Detected", "/ik_error", VisualizationManager.message_widget_preset() | {"text_color": self._error_text_color}) + +def _sample_update_error_text_color(self, mgr: VisualizationManager, data_collector: DataCollector) -> None: + self._error_text_color = self._error_text_color + 0x100 + if self._error_text_color >= 0xFFFFFFFF: + self._error_text_color = 0xFF0000FF + +def _sample_update_left_panel(self, mgr, data_collector) -> None: + """Update the left panel with current data and update counter. + + Args: + data_collector: DataCollector instance containing current data + """ + left_panel_id = getattr(self, '_left_panel_id', None) + if left_panel_id is not None: + content = f"{mgr._left_panel_updated_times}\n{data_collector.make_panel_content()}" + update_instruction(left_panel_id, content) + mgr._left_panel_updated_times += 1 + +def _sample_update_right_panel(self, mgr, data_collector) -> None: + """Update the right panel with current data and update counter. + + Args: + data_collector: DataCollector instance containing current data + """ + right_panel_id = getattr(self, '_right_panel_id', None) + if right_panel_id is not None: + content = f"{mgr._right_panel_updated_times}\n{data_collector.make_panel_content()}" + update_instruction(right_panel_id, content) + mgr._right_panel_updated_times += 1 + +def apply_sample_visualization(): + # Error Message + XRVisualization.register_callback(TriggerType.TRIGGER_ON_EVENT, {"event_name": "ik_error"}, _sample_handle_ik_error) + + # Display a panel on the left to display DataCollector data + # Refresh periodically + # Todo: use a better way to add '/' to pathname + XRVisualization.set_attrs({ + "left_panel_id": "/left_panel", + "left_panel_translation": Gf.Vec3f(-2, 2.6, 2), + "left_panel_updated_times": 0, + "right_panel_updated_times": 0, + }) + XRVisualization.register_callback(TriggerType.TRIGGER_ON_PERIOD, {"period": 1.0}, _sample_update_left_panel) + + # Display a panel on the right to display DataCollector data + # Refresh when data changes + XRVisualization.set_attrs({ + "right_panel_id": "/right_panel", + "right_panel_translation": Gf.Vec3f(1.5, 2, 2), + }) + XRVisualization.register_callback(TriggerType.TRIGGER_ON_EVENT, {"event_name": "default_event_has_change"}, _sample_update_right_panel) + + # Change error text color every second + XRVisualization.set_attrs({ + "error_text_color": 0xFF0000FF, + }) + XRVisualization.register_callback(TriggerType.TRIGGER_ON_UPDATE, {}, _sample_update_error_text_color) + diff --git a/source/isaaclab/isaaclab/devices/device_base.py b/source/isaaclab/isaaclab/devices/device_base.py index b7955468cc11..54d7594941f7 100644 --- a/source/isaaclab/isaaclab/devices/device_base.py +++ b/source/isaaclab/isaaclab/devices/device_base.py @@ -12,7 +12,7 @@ from typing import Any from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg - +from isaaclab.ui.xr_widgets import XRVisualization @dataclass class DeviceCfg: @@ -110,6 +110,8 @@ def advance(self) -> torch.Tensor: """ raw_data = self._get_raw_data() + XRVisualization.push_data({"device_raw_data": raw_data}) + # If no retargeters, return raw data directly (not as a tuple) if not self._retargeters: return raw_data diff --git a/source/isaaclab/isaaclab/devices/openxr/retargeters/humanoid/fourier/gr1t2_retargeter.py b/source/isaaclab/isaaclab/devices/openxr/retargeters/humanoid/fourier/gr1t2_retargeter.py index 4548c0f99cba..282c1ff69424 100644 --- a/source/isaaclab/isaaclab/devices/openxr/retargeters/humanoid/fourier/gr1t2_retargeter.py +++ b/source/isaaclab/isaaclab/devices/openxr/retargeters/humanoid/fourier/gr1t2_retargeter.py @@ -12,12 +12,12 @@ import isaaclab.utils.math as PoseUtils from isaaclab.devices import OpenXRDevice from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg -from isaaclab.markers import VisualizationMarkers, VisualizationMarkersCfg # This import exception is suppressed because gr1_t2_dex_retargeting_utils depends on pinocchio which is not available on windows with contextlib.suppress(Exception): from .gr1_t2_dex_retargeting_utils import GR1TR2DexRetargeting +from isaaclab.ui.xr_widgets import XRVisualization @dataclass class GR1T2RetargeterCfg(RetargeterCfg): @@ -48,6 +48,20 @@ def __init__( hand_joint_names: List of robot hand joint names """ + XRVisualization.push_event("enable_teleop_visualization", cfg.enable_visualization) + XRVisualization.push_data({"sim_device": cfg.sim_device}) + hand_torque_mapping = ["L_thumb_proximal_pitch_joint", "L_thumb_distal_joint", + "L_index_proximal_joint", "L_index_intermediate_joint", "L_index_intermediate_joint", + "L_middle_proximal_joint", "L_middle_intermediate_joint", "L_middle_intermediate_joint", + "L_ring_proximal_joint", "L_ring_intermediate_joint", "L_ring_intermediate_joint", + "L_pinky_proximal_joint", "L_pinky_intermediate_joint", "L_pinky_intermediate_joint", + "R_thumb_proximal_pitch_joint", "R_thumb_distal_joint", + "R_index_proximal_joint", "R_index_intermediate_joint", "R_index_intermediate_joint", + "R_middle_proximal_joint", "R_middle_intermediate_joint", "R_middle_intermediate_joint", + "R_ring_proximal_joint", "R_ring_intermediate_joint", "R_ring_intermediate_joint", + "R_pinky_proximal_joint", "R_pinky_intermediate_joint", "R_pinky_intermediate_joint"] + XRVisualization.push_data({"hand_torque_mapping": hand_torque_mapping}) + self._hand_joint_names = cfg.hand_joint_names self._hands_controller = GR1TR2DexRetargeting(self._hand_joint_names) @@ -55,17 +69,6 @@ def __init__( self._enable_visualization = cfg.enable_visualization self._num_open_xr_hand_joints = cfg.num_open_xr_hand_joints self._sim_device = cfg.sim_device - if self._enable_visualization: - marker_cfg = VisualizationMarkersCfg( - prim_path="/Visuals/markers", - markers={ - "joint": sim_utils.SphereCfg( - radius=0.005, - visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 0.0, 0.0)), - ), - }, - ) - self._markers = VisualizationMarkers(marker_cfg) def retarget(self, data: dict) -> torch.Tensor: """Convert hand joint poses to robot end-effector commands. @@ -87,14 +90,6 @@ def retarget(self, data: dict) -> torch.Tensor: left_wrist = left_hand_poses.get("wrist") right_wrist = right_hand_poses.get("wrist") - if self._enable_visualization: - joints_position = np.zeros((self._num_open_xr_hand_joints, 3)) - - joints_position[::2] = np.array([pose[:3] for pose in left_hand_poses.values()]) - joints_position[1::2] = np.array([pose[:3] for pose in right_hand_poses.values()]) - - self._markers.visualize(translations=torch.tensor(joints_position, device=self._sim_device)) - # Create array of zeros with length matching number of joint names left_hands_pos = self._hands_controller.compute_left(left_hand_poses) indexes = [self._hand_joint_names.index(name) for name in self._hands_controller.get_left_joint_names()] diff --git a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py index c29d203b07b9..139f45015baa 100644 --- a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py +++ b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py @@ -17,7 +17,7 @@ from isaaclab.managers import CommandManager, CurriculumManager, RewardManager, TerminationManager from isaaclab.ui.widgets import ManagerLiveVisualizer - +from isaaclab.ui.xr_widgets import XRVisualization from .common import VecEnvStepReturn from .manager_based_env import ManagerBasedEnv from .manager_based_rl_env_cfg import ManagerBasedRLEnvCfg @@ -196,6 +196,14 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: # update buffers at sim dt self.scene.update(dt=self.physics_dt) + # Todo: torque limit doesn't need updating every frame + # get joint torque limits + joints_torque_limit = self.scene["robot"].data.joint_effort_limits[0] + # get joint torque + joints_torque = self.scene["robot"].data.applied_torque[0] + joints_name = self.scene["robot"].data.joint_names + XRVisualization.push_data({"joints_torque": joints_torque, "joints_torque_limit": joints_torque_limit, "joints_name": joints_name}) + # post-step: # -- update env counters (used for curriculum generation) self.episode_length_buf += 1 # step in current episode (per env) diff --git a/source/isaaclab/isaaclab/ui/xr_widgets/__init__.py b/source/isaaclab/isaaclab/ui/xr_widgets/__init__.py index 5b9b39ec156c..26f74c8d948e 100644 --- a/source/isaaclab/isaaclab/ui/xr_widgets/__init__.py +++ b/source/isaaclab/isaaclab/ui/xr_widgets/__init__.py @@ -2,4 +2,5 @@ # All rights reserved. # # SPDX-License-Identifier: BSD-3-Clause -from .instruction_widget import SimpleTextWidget, show_instruction +from .instruction_widget import SimpleTextWidget, show_instruction, hide_instruction, update_instruction +from .scene_visualization import XRVisualization, TriggerType, DataCollector, VisualizationManager diff --git a/source/isaaclab/isaaclab/ui/xr_widgets/instruction_widget.py b/source/isaaclab/isaaclab/ui/xr_widgets/instruction_widget.py index 65de79f155b2..297b1eec4faf 100644 --- a/source/isaaclab/isaaclab/ui/xr_widgets/instruction_widget.py +++ b/source/isaaclab/isaaclab/ui/xr_widgets/instruction_widget.py @@ -6,14 +6,18 @@ import asyncio import functools import textwrap -from typing import Any, TypeAlias +from typing import Any, TypeAlias, Optional +from collections.abc import Callable import omni.kit.commands import omni.ui as ui from isaacsim.core.utils.prims import delete_prim, get_prim_at_path from omni.kit.xr.scene_view.utils import UiContainer, WidgetComponent from omni.kit.xr.scene_view.utils.spatial_source import SpatialSource -from pxr import Gf +from pxr import Gf, Usd +from isaacsim.core.prims import XFormPrim +import isaacsim.core.utils.stage as stage_utils +from isaacsim.core.utils.stage import get_current_stage Vec3Type: TypeAlias = Gf.Vec3f | Gf.Vec3d @@ -42,7 +46,10 @@ def _build_ui(self): with ui.ZStack(): ui.Rectangle(style={"Rectangle": {"background_color": 0xFF454545, "border_radius": 0.1}}) with ui.VStack(alignment=ui.Alignment.CENTER): - self._ui_label = ui.Label(self._text, style=self._style, alignment=ui.Alignment.CENTER) + # Ensure text is never None for ui.Label + display_text = self._text if self._text is not None else "Simple Text" + self._ui_label = ui.Label(display_text, style=self._style, alignment=ui.Alignment.CENTER) + # self._ui_label = ui.Label(display_text, style=self._style, alignment=ui.Alignment.LEFT_TOP) def compute_widget_dimensions( @@ -68,7 +75,6 @@ def compute_widget_dimensions( actual_height = len(lines) * line_height return actual_width, actual_height, lines - def show_instruction( text: str, prim_path_source: str | None = None, @@ -77,7 +83,12 @@ def show_instruction( max_width: float = 2.5, min_width: float = 1.0, # Prevent widget from being too narrow. font_size: float = 0.1, + text_color: int = 0xFFFFFFFF, target_prim_path: str = "/newPrim", + callback_on_hide: Optional[Callable] = None, + preferred_width: float = 0.0, + preferred_height: float = 0.0, + is_billboard: bool = True ) -> UiContainer | None: """ Create and display the instruction widget based on the given text. @@ -122,7 +133,15 @@ def show_instruction( delete_prim(target_prim_path) # Compute dimensions and wrap text. + if preferred_width > 0.0: + max_width = preferred_width + min_width = preferred_width + width, height, lines = compute_widget_dimensions(text, font_size, max_width, min_width) + + if preferred_height > 0.0: + height = preferred_height + wrapped_text = "\n".join(lines) # Create the widget component. @@ -131,7 +150,7 @@ def show_instruction( width=width, height=height, resolution_scale=300, - widget_args=[wrapped_text, {"font_size": font_size}], + widget_args=[wrapped_text, {"font_size": font_size, "color": text_color}], ) copied_prim = omni.kit.commands.execute( @@ -142,14 +161,28 @@ def show_instruction( copy_to_introducing_layer=False, ) +# Todo: stop copy prim, copy transform instead + # stage = get_current_stage() + + # if stage.GetPrimAtPath(prim_path_source).IsValid(): + # prim = stage.GetPrimAtPath(prim_path_source) + + # new_prim = XFormPrim(prim_paths_expr=target_prim_path, positions=prim.) + # assert new_prim.is_valid() + space_stack = [] if copied_prim is not None: space_stack.append(SpatialSource.new_prim_path_source(target_prim_path)) - space_stack.extend([ - SpatialSource.new_translation_source(translation), - SpatialSource.new_look_at_camera_source(), - ]) + if is_billboard: + space_stack.extend([ + SpatialSource.new_translation_source(translation), + SpatialSource.new_look_at_camera_source(), + ]) + else: + space_stack.extend([ + SpatialSource.new_translation_source(translation), + ]) # Create the UI container with the widget. container = UiContainer( @@ -160,17 +193,20 @@ def show_instruction( # Schedule auto-hide after the specified display_duration if provided. if display_duration: - timer = asyncio.get_event_loop().call_later(display_duration, functools.partial(hide, target_prim_path)) + timer = asyncio.get_event_loop().call_later(display_duration, functools.partial(hide_instruction, target_prim_path, callback_on_hide)) camera_facing_widget_timers[target_prim_path] = timer return container - -def hide(target_prim_path: str = "/newPrim") -> None: +def hide_instruction(target_prim_path: str = "/newPrim", callback: Optional[Callable] = None) -> None: """ Hide and clean up a specific instruction widget. Also cleans up associated timer. """ + + if callback: + callback() + global camera_facing_widget_container, camera_facing_widget_timers if target_prim_path in camera_facing_widget_container: @@ -180,3 +216,36 @@ def hide(target_prim_path: str = "/newPrim") -> None: if target_prim_path in camera_facing_widget_timers: del camera_facing_widget_timers[target_prim_path] + +def update_instruction(target_prim_path: str = "/newPrim", text: str = ""): + """ + Update the text content of an existing instruction widget. + + Args: + target_prim_path (str): The path of the widget to update. + text (str): The new text content to display. + """ + global camera_facing_widget_container + + container_data = camera_facing_widget_container.get(target_prim_path) + if container_data: + container, current_text = container_data + + # Only update if the text has actually changed + if current_text != text: + # Access the widget through the manipulator as shown in ui_container.py + manipulator = container.manipulator + + # The WidgetComponent is stored in the manipulator's components + # Try to access the widget component and then the actual widget + components = getattr(manipulator, '_ComposableManipulator__components') + if len(components) > 0: + simple_text_widget = components[0] + if simple_text_widget and simple_text_widget.component and simple_text_widget.component.widget: + simple_text_widget.component.widget.set_label_text(text) + + # Update the stored text in the global dictionary + camera_facing_widget_container[target_prim_path] = (container, text) + return True + + return False diff --git a/source/isaaclab/isaaclab/ui/xr_widgets/scene_visualization.py b/source/isaaclab/isaaclab/ui/xr_widgets/scene_visualization.py new file mode 100644 index 000000000000..bb571d3c4c30 --- /dev/null +++ b/source/isaaclab/isaaclab/ui/xr_widgets/scene_visualization.py @@ -0,0 +1,607 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations +from typing import Any, Optional, Union +from collections.abc import Callable +import threading +from enum import Enum +import torch +import time +import numpy as np +import inspect +from pxr import Gf +from isaaclab.ui.xr_widgets import show_instruction +from isaaclab.sim import SimulationContext +import omni.log + + + + +class TriggerType(Enum): + """Enumeration of trigger types for visualization callbacks. + + Defines when callbacks should be executed: + - TRIGGER_ON_EVENT: Execute when a specific event occurs + - TRIGGER_ON_PERIOD: Execute at regular time intervals + - TRIGGER_ON_CHANGE: Execute when a specific data variable changes + """ + + TRIGGER_ON_EVENT = 0 + TRIGGER_ON_PERIOD = 1 + TRIGGER_ON_CHANGE = 2 + TRIGGER_ON_UPDATE = 3 + + +class DataCollector: + """Collects and manages data for visualization purposes. + + This class provides a centralized data store for visualization data, + with change detection and callback mechanisms for real-time updates. + """ + + def __init__(self): + """Initialize the data collector with empty data store and callback system.""" + self._data: dict[str, Any] = {} + self._visualization_callback: Optional[Callable] = None + self._changed_flags: set[str] = set() + + def _values_equal(self, existing_value: Any, new_value: Any) -> bool: + """Compare two values using appropriate method based on their types. + + Handles different data types including None, NumPy arrays, PyTorch tensors, + and standard Python types for accurate change detection. + + Args: + existing_value: The current value stored in the data collector + new_value: The new value to compare against + + Returns: + bool: True if values are equal, False otherwise + """ + # If both are None or one is None + if existing_value is None or new_value is None: + return existing_value is new_value + + # If types are different, they're not equal + if type(existing_value) != type(new_value): + return False + + # Handle NumPy arrays + if isinstance(existing_value, np.ndarray): + return np.array_equal(existing_value, new_value) + + # Handle torch tensors (if they exist) + if hasattr(existing_value, 'equal'): + try: + return torch.equal(existing_value, new_value) + except: + pass + + # For all other types (int, float, string, bool, list, dict, set), use regular equality + try: + return existing_value == new_value + except Exception: + # If comparison fails for any reason, assume they're different + return False + + def update_data(self, name: str, value: Any) -> None: + """Update a data field and trigger change detection. + + This method handles data updates with intelligent change detection. + It also performs pre-processing and post-processing based on the field name. + + Args: + name: The name/key of the data field to update + value: The new value to store (None to remove the field) + """ + existing_value = self.get_data(name) + + if value is None: + self._data.pop(name) + if existing_value is not None: + self._changed_flags.add(name) + return + + # Todo: for list or array, the change won't be detected + # Check if the value has changed using appropriate comparison method + if self._values_equal(existing_value, value): + return + + # Save it + self._data[name] = value + self._changed_flags.add(name) + + def update_loop(self) -> None: + """Process pending changes and trigger visualization callbacks. + + This method should be called regularly to ensure visualization updates + are processed in a timely manner. + """ + if len(self._changed_flags) > 0: + if self._visualization_callback: + self._visualization_callback(self._changed_flags) + self._changed_flags.clear() + + def get_data(self, name: str) -> Any: + """Retrieve data by name. + + Args: + name: The name/key of the data field to retrieve + + Returns: + The stored value, or None if the field doesn't exist + """ + return self._data.get(name) + + def set_visualization_callback(self, callback: Callable) -> None: + """Set the VisualizationManager callback function to be called when data changes. + + Args: + callback: Function to call when data changes, receives set of changed field names + """ + self._visualization_callback = callback + + +class VisualizationManager: + """Base class for managing visualization rules and callbacks. + + Provides a framework for registering and executing callbacks based on + different trigger conditions (events, time periods, data changes). + """ + + # Type aliases for different callback signatures + StandardCallback = Callable[['VisualizationManager', 'DataCollector'], None] + EventCallback = Callable[['VisualizationManager', 'DataCollector', Any], None] + CallbackType = Union[StandardCallback, EventCallback] + + class TimeCountdown: + """Internal class for managing periodic timer-based callbacks.""" + period: float + countdown: float + last_time: float + + def __init__(self, period: float, initial_countdown: float = 0.0): + """Initialize a countdown timer. + + Args: + period: Time interval in seconds between callback executions + """ + self.period = period + self.countdown = initial_countdown + self.last_time = time.time() + + def update(self, current_time: float) -> bool: + """Update the countdown timer and check if callback should be triggered. + + Args: + current_time: Current time in seconds + + Returns: + bool: True if callback should be triggered, False otherwise + """ + self.countdown -= (current_time - self.last_time) + self.last_time = current_time + if self.countdown <= 0.0: + self.countdown = self.period + return True + return False + + # Widget presets for common visualization configurations + @classmethod + def message_widget_preset(cls) -> dict[str, Any]: + """Get the message widget preset configuration. + + Returns: + dict: Configuration dictionary for message widgets + """ + return { + "prim_path_source": "/_xr/stage/xrCamera", + "translation": Gf.Vec3f(0, 0, -2), + "display_duration": 3.0, + "max_width": 2.5, + "min_width": 1.0, + "font_size": 0.1, + "text_color": 0xFF00FFFF, + } + + @classmethod + def panel_widget_preset(cls) -> dict[str, Any]: + """Get the panel widget preset configuration. + + Returns: + dict: Configuration dictionary for panel widgets + """ + return { + "prim_path_source": "/XRAnchor", + "translation": Gf.Vec3f(0, 2, 2), # hard-coded temporarily + "display_duration": 0.0, + "font_size": 0.13, + "preferred_width": 2, + "preferred_height": 3, + #"is_billboard": False, + } + + def display_widget(self, text: str, name: str, args: dict[str, Any]) -> None: + """Display a widget with the given text and configuration. + + Args: + text: Text content to display in the widget + name: Unique identifier for the widget. If duplicated, the old one will be removed from scene. + args: Configuration dictionary for widget appearance and behavior + """ + widget_config = args | {"text": text, "target_prim_path": name} + show_instruction(**widget_config) + + def __init__(self, data_collector: DataCollector): + """Initialize the visualization manager. + + Args: + data_collector: DataCollector instance to access the data for visualization use. + """ + self.data_collector: DataCollector = data_collector + data_collector.set_visualization_callback(self.on_change) + + self._rules_on_period: dict[VisualizationManager.TimeCountdown, VisualizationManager.StandardCallback] = {} + self._rules_on_event: dict[str, list[VisualizationManager.EventCallback]] = {} + self._rules_on_change: dict[str, list[VisualizationManager.StandardCallback]] = {} + self._rules_on_update: list[VisualizationManager.StandardCallback] = [] + + # Todo: add support to registering same callbacks for different names + def on_change(self, names: set[str]) -> None: + """Handle data changes by executing registered callbacks. + + Args: + names: Set of data field names that have changed + """ + for name in names: + callbacks = self._rules_on_change.get(name) + if callbacks: + # Create a copy of the list to avoid modification during iteration + for callback in list(callbacks): + callback(self, self.data_collector) + if len(names) > 0: + self.on_event("default_event_has_change") + + def update_loop(self) -> None: + """Update periodic timers and execute callbacks as needed. + + This method should be called regularly to ensure periodic callbacks + are executed at the correct intervals. + """ + + # Create a copy of the list to avoid modification during iteration + for callback in list(self._rules_on_update): + callback(self, self.data_collector) + + current_time = time.time() + # Create a copy of the items to avoid modification during iteration + for timer, callback in list(self._rules_on_period.items()): + triggered = timer.update(current_time) + if triggered: + callback(self, self.data_collector) + + def on_event(self, event: str, params: Any = None) -> None: + """Handle events by executing registered callbacks. + + Args: + event: Name of the event that occurred + """ + callbacks = self._rules_on_event.get(event) + if callbacks is None: + return + # Create a copy of the list to avoid modification during iteration + for callback in list(callbacks): + callback(self, self.data_collector, params) + + # Todo: better organization of callbacks + def register_callback(self, trigger: TriggerType, arg: dict, callback: CallbackType) -> Any: + """Register a callback function to be executed based on trigger conditions. + + Args: + trigger: Type of trigger that should execute the callback + arg: Dictionary containing trigger-specific parameters: + - For TRIGGER_ON_PERIOD: {"period": float} + - For TRIGGER_ON_EVENT: {"event_name": str} + - For TRIGGER_ON_CHANGE: {"variable_name": str} + - For TRIGGER_ON_UPDATE: {} + callback: Function to execute when trigger condition is met + - For TRIGGER_ON_EVENT: callback(manager: VisualizationManager, data_collector: DataCollector, event_params: Any) + - For others: callback(manager: VisualizationManager, data_collector: DataCollector) + + Raises: + TypeError: If callback signature doesn't match the expected signature for the trigger type + """ + # Validate callback signature based on trigger type + self._validate_callback_signature(trigger, callback) + + match trigger: + case TriggerType.TRIGGER_ON_PERIOD: + period = arg.get("period") + initial_countdown = arg.get("initial_countdown", 0.0) + if isinstance(period, float) and isinstance(initial_countdown, float): + timer = VisualizationManager.TimeCountdown(period=period, initial_countdown=initial_countdown) + # Type cast since we've validated the signature + self._rules_on_period[timer] = callback # type: ignore + return timer + case TriggerType.TRIGGER_ON_EVENT: + event = arg.get("event_name") + if isinstance(event, str): + callbacks = self._rules_on_event.get(event) + if callbacks is None: + # Type cast since we've validated the signature + self._rules_on_event[event] = [callback] # type: ignore + else: + # Type cast since we've validated the signature + self._rules_on_event[event].append(callback) # type: ignore + return event + case TriggerType.TRIGGER_ON_CHANGE: + variable_name = arg.get("variable_name") + if isinstance(variable_name, str): + callbacks = self._rules_on_change.get(variable_name) + if callbacks is None: + # Type cast since we've validated the signature + self._rules_on_change[variable_name] = [callback] # type: ignore + else: + # Type cast since we've validated the signature + self._rules_on_change[variable_name].append(callback) # type: ignore + return variable_name + case TriggerType.TRIGGER_ON_UPDATE: + # Type cast since we've validated the signature + self._rules_on_update.append(callback) # type: ignore + return None + + # Todo: better callback-cancel method + def cancel_rule(self, trigger: TriggerType, arg: str | TimeCountdown, callback: Optional[Callable] = None) -> None: + """Remove a previously registered callback. + + Periodic callbacks are not supported to be cancelled for now. + + Args: + trigger: Type of trigger for the callback to remove + arg: Trigger-specific identifier (event name or variable name) + callback: The callback function to remove + """ + callbacks = None + match trigger: + case TriggerType.TRIGGER_ON_CHANGE: + callbacks = self._rules_on_change.get(arg) + case TriggerType.TRIGGER_ON_EVENT: + callbacks = self._rules_on_event.get(arg) + case TriggerType.TRIGGER_ON_PERIOD: + self._rules_on_period.pop(arg) + case TriggerType.TRIGGER_ON_UPDATE: + callbacks = self._rules_on_update + if callbacks is not None: + if callback is not None: + callbacks.remove(callback) + else: + callbacks.clear() + + def set_attr(self, name: str, value: Any) -> None: + """Set an attribute of the visualization manager. + + Args: + name: Name of the attribute to set + value: Value to set the attribute to + """ + setattr(self, name, value) + + def _validate_callback_signature(self, trigger: TriggerType, callback: Callable) -> None: + """Validate that the callback has the correct signature for the trigger type. + + Args: + trigger: Type of trigger for the callback + callback: The callback function to validate + + Raises: + TypeError: If callback signature doesn't match expected signature + """ + try: + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + # Remove 'self' parameter if it's a bound method + if params and params[0].name == 'self': + params = params[1:] + + param_count = len(params) + + if trigger == TriggerType.TRIGGER_ON_EVENT: + # Event callbacks should have 3 parameters: (manager, data_collector, event_params) + expected_count = 3 + expected_sig = "callback(manager: VisualizationManager, data_collector: DataCollector, event_params: Any)" + else: + # Other callbacks should have 2 parameters: (manager, data_collector) + expected_count = 2 + expected_sig = "callback(manager: VisualizationManager, data_collector: DataCollector)" + + if param_count != expected_count: + raise TypeError( + f"Callback for {trigger.name} must have {expected_count} parameters, " + f"but got {param_count}. Expected signature: {expected_sig}. " + f"Actual signature: {sig}" + ) + + except Exception as e: + if isinstance(e, TypeError): + raise + # If we can't inspect the signature (e.g., built-in functions), + # just log a warning and proceed + omni.log.warn(f"Could not validate callback signature for {trigger.name}: {e}") + + + + +class XRVisualization: + """Singleton class providing XR visualization functionality. + + This class implements the singleton pattern to ensure only one instance + of the visualization system exists across the application. It provides + a centralized API for managing XR visualization features. + + When manage a new event ordata field, please add a comment to the following list. + + Event names: + "ik_solver_failed" + + Data fields: + "manipulability_ellipsoid" : list[float] + "device_raw_data" : dict + "joints_distance_percentage_to_limit" : list[float] + "joints_torque" : list[float] + "joints_torque_limit" : list[float] + "joints_name" : list[str] + "wrist_pose" : list[float] + "approximated_working_space" : list[float] + "hand_torque_mapping" : list[str] + """ + + _lock = threading.Lock() + _instance: Optional[XRVisualization] = None + _registered = False + + def __init__(self): + """Prevent direct instantiation.""" + raise RuntimeError("Use VisualizationInterface classmethods instead of direct instantiation") + + @classmethod + def __create_instance(cls, manager: type[VisualizationManager] = VisualizationManager) -> XRVisualization: + """Get the visualization manager instance. + + Returns: + VisualizationManager: The visualization manager instance + """ + with cls._lock: + if cls._instance is None: + # Bypass __init__ by calling __new__ directly + cls._instance = super(XRVisualization, cls).__new__(cls) + cls._instance._initialize(manager) + return cls._instance + + @classmethod + def __get_instance(cls) -> XRVisualization: + """Thread-safe singleton access. + + Returns: + XRVisualization: The singleton instance of the visualization system + """ + if cls._instance is None: + return cls.__create_instance() + elif not cls._instance._registered: + cls._instance._register() + return cls._instance + + def _register(self) -> bool: + """Register the visualization system. + + Returns: + bool: True if the visualization system is registered, False otherwise + """ + if self._registered: + return True + + sim = SimulationContext.instance() + if sim is not None: + sim.add_render_callback("visualization_render_callback", self.update_loop) + self._registered = True + return self._registered + + def _initialize(self, manager: type[VisualizationManager]) -> None: + """Initialize the singleton instance with data collector and visualization manager.""" + + self._data_collector = DataCollector() + self._visualization_manager = manager(self._data_collector) + + self._register() + + self._initialized = True + + # APIs + + def update_loop(self, event) -> None: + """Update the visualization system. + + This method should be called regularly (e.g., every frame) to ensure + visualization updates are processed and periodic callbacks are executed. + """ + self._visualization_manager.update_loop() + self._data_collector.update_loop() + + @classmethod + def push_event(cls, name: str, args: Any = None) -> None: + """Push an event to trigger registered callbacks. + + Args: + name: Name of the event to trigger + args: Optional arguments for the event (currently unused) + """ + instance = cls.__get_instance() + instance._visualization_manager.on_event(name, args) + + @classmethod + def push_data(cls, item: dict[str, Any]) -> None: + """Push data to the visualization system. + + Updates multiple data fields at once. Each key-value pair in the + dictionary will be processed by the data collector. + + Args: + item: Dictionary containing data field names and their values + """ + instance = cls.__get_instance() + for name, value in item.items(): + instance._data_collector.update_data(name, value) + + @classmethod + def set_attrs(cls, attributes: dict[str, Any]) -> None: + """Set configuration data for the visualization system. Not currently used. + + Args: + attributes: Dictionary containing configuration keys and values + """ + + instance = cls.__get_instance() + for name, data in attributes.items(): + instance._visualization_manager.set_attr(name, data) + + @classmethod + def get_attr(cls, name: str) -> Any: + """Get configuration data for the visualization system. Not currently used. + + Args: + name: Configuration key + """ + instance = cls.__get_instance() + return getattr(instance._visualization_manager, name) + + @classmethod + def register_callback(cls, trigger: TriggerType, arg: dict, callback: VisualizationManager.CallbackType) -> None: + """Register a callback function for visualization events. + + Args: + trigger: Type of trigger that should execute the callback + arg: Dictionary containing trigger-specific parameters: + - For TRIGGER_ON_PERIOD: {"period": float} + - For TRIGGER_ON_EVENT: {"event_name": str} + - For TRIGGER_ON_CHANGE: {"variable_name": str} + callback: Function to execute when trigger condition is met + """ + instance = cls.__get_instance() + instance._visualization_manager.register_callback(trigger, arg, callback) + + @classmethod + def assign_manager(cls, manager: type[VisualizationManager]) -> None: + """Assign a visualization manager type to the visualization system. + + Args: + manager: Type of the visualization manager to assign + """ + if cls._instance is not None: + omni.log.error(f"Visualization system already initialized to {type(cls._instance._visualization_manager).__name__}, cannot assign manager {manager.__name__}") + return + + cls.__create_instance(manager) + From f11536b15a7d5b5fdda73c6b05ab06c86151180d Mon Sep 17 00:00:00 2001 From: Lotus Li Date: Tue, 26 Aug 2025 16:28:18 -0700 Subject: [PATCH 2/2] fix attr check --- scripts/tools/teleop_visualization_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/tools/teleop_visualization_manager.py b/scripts/tools/teleop_visualization_manager.py index e87729f710e8..e02621998e99 100644 --- a/scripts/tools/teleop_visualization_manager.py +++ b/scripts/tools/teleop_visualization_manager.py @@ -265,7 +265,7 @@ def _update_torque_skeleton(self, mgr: VisualizationManager, data_collector: Dat hand_torque_mapping : Final = data_collector.get_data("hand_torque_mapping") # enable_torque_color needs to be manually set by calling XRVisualization.set_attrs({"enable_torque_color": True}) - if getattr(mgr, "enable_torque_color") and joints_torque is not None and joints_torque_limit is not None and joints_name is not None and hand_torque_mapping is not None and len(hand_torque_mapping) == len(colors_lines) - 10: + if getattr(mgr, "enable_torque_color", None) and joints_torque is not None and joints_torque_limit is not None and joints_name is not None and hand_torque_mapping is not None and len(hand_torque_mapping) == len(colors_lines) - 10: # Insert empty strings at positions that are not fingers hand_torque_mapping_copy = hand_torque_mapping.copy() hand_torque_mapping_copy.append("")