From 9e725e739834c41f9c3735d702406894eb345206 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 10 Mar 2026 13:07:12 -0700 Subject: [PATCH 01/20] Fix Tiled Camera bug when set_world_poses_from_view is called --- .../isaaclab/sensors/camera/tiled_camera.py | 4 +- .../test/sensors/test_tiled_camera.py | 91 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py index e6ae56adaa53..c45d3405b727 100644 --- a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py @@ -171,7 +171,9 @@ def _initialize_impl(self): self.renderer.prepare_stage(self.stage, self._num_envs) # Create a view for the sensor - self._view = XformPrimView(self.cfg.prim_path, device=self._device, stage=self.stage) + self._view = XformPrimView( + self.cfg.prim_path, device=self._device, stage=self.stage, sync_usd_on_fabric_write=True + ) # Check that sizes are correct if self._view.count != self._num_envs: raise RuntimeError( diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index 0fcf5f26c612..d3ac89b35168 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -1615,6 +1615,97 @@ def test_output_equal_to_usd_camera_intrinsics(setup_camera, device): del camera_usd +@pytest.mark.parametrize( + "camera_cls,cfg_cls", + [(TiledCamera, TiledCameraCfg), (Camera, CameraCfg)], + ids=["tiled", "non_tiled"], +) +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +@pytest.mark.isaacsim_ci +def test_camera_pose_update_reflected_in_render( + setup_camera, device, camera_cls, cfg_cls +): + """Test that moving a camera is reflected in rendered depth. + + Both camera types must produce different depth images when the + camera is repositioned from close to far. + """ + sim, _, dt = setup_camera + cam_cfg = cfg_cls( + prim_path="/World/PoseTestCam", + height=128, + width=256, + update_period=0, + update_latest_camera_pose=True, + data_types=["distance_to_camera"], + spawn=sim_utils.PinholeCameraCfg( + focal_length=24.0, + focus_distance=400.0, + horizontal_aperture=20.955, + clipping_range=(0.1, 1.0e5), + ), + ) + camera = camera_cls(cam_cfg) + + sim.reset() + + target = torch.tensor( + [[0.0, 0.0, 0.0]], + dtype=torch.float32, + device=camera.device, + ) + + # Position A: close to scene objects + eyes_close = torch.tensor( + [[2.0, 2.0, 2.0]], + dtype=torch.float32, + device=camera.device, + ) + camera.set_world_poses_from_view(eyes_close, target) + for _ in range(5): + sim.step() + camera.update(dt) + depth_close = ( + camera.data.output["distance_to_camera"].clone() + ) + + # Position B: far from scene objects + eyes_far = torch.tensor( + [[8.0, 8.0, 8.0]], + dtype=torch.float32, + device=camera.device, + ) + camera.set_world_poses_from_view(eyes_far, target) + for _ in range(2): + sim.step() + camera.update(dt) + depth_far = ( + camera.data.output["distance_to_camera"].clone() + ) + + max_range = cam_cfg.spawn.clipping_range[1] + valid_close = depth_close[depth_close < max_range] + valid_far = depth_far[depth_far < max_range] + + assert valid_close.numel() > 0, ( + "No valid depth pixels from close position" + ) + assert valid_far.numel() > 0, ( + "No valid depth pixels from far position" + ) + + mean_close = valid_close.mean().item() + mean_far = valid_far.mean().item() + + assert mean_far > mean_close * 1.5, ( + f"Mean depth from far ({mean_far:.2f}) should be" + f" >= 1.5x close ({mean_close:.2f}). Renderer may" + " not be observing the updated camera pose." + ) + + del camera + + @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) @pytest.mark.isaacsim_ci def test_sensor_print(setup_camera, device): From a9f61a06fee874cac61f087403e42af741649008 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 10 Mar 2026 13:13:03 -0700 Subject: [PATCH 02/20] Reformat and add comment --- .../isaaclab/sensors/camera/tiled_camera.py | 2 +- .../test/sensors/test_tiled_camera.py | 20 +++++-------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py index c45d3405b727..c829f302613f 100644 --- a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py @@ -170,7 +170,7 @@ def _initialize_impl(self): # the view keeps references to the prims located in the stage self.renderer.prepare_stage(self.stage, self._num_envs) - # Create a view for the sensor + # Create a view for the sensor with Fabric enabled for fast pose queries, otherwise position will be stale. self._view = XformPrimView( self.cfg.prim_path, device=self._device, stage=self.stage, sync_usd_on_fabric_write=True ) diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index d3ac89b35168..b125b63b548e 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -1622,9 +1622,7 @@ def test_output_equal_to_usd_camera_intrinsics(setup_camera, device): ) @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) @pytest.mark.isaacsim_ci -def test_camera_pose_update_reflected_in_render( - setup_camera, device, camera_cls, cfg_cls -): +def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls, cfg_cls): """Test that moving a camera is reflected in rendered depth. Both camera types must produce different depth images when the @@ -1665,9 +1663,7 @@ def test_camera_pose_update_reflected_in_render( for _ in range(5): sim.step() camera.update(dt) - depth_close = ( - camera.data.output["distance_to_camera"].clone() - ) + depth_close = camera.data.output["distance_to_camera"].clone() # Position B: far from scene objects eyes_far = torch.tensor( @@ -1679,20 +1675,14 @@ def test_camera_pose_update_reflected_in_render( for _ in range(2): sim.step() camera.update(dt) - depth_far = ( - camera.data.output["distance_to_camera"].clone() - ) + depth_far = camera.data.output["distance_to_camera"].clone() max_range = cam_cfg.spawn.clipping_range[1] valid_close = depth_close[depth_close < max_range] valid_far = depth_far[depth_far < max_range] - assert valid_close.numel() > 0, ( - "No valid depth pixels from close position" - ) - assert valid_far.numel() > 0, ( - "No valid depth pixels from far position" - ) + assert valid_close.numel() > 0, "No valid depth pixels from close position" + assert valid_far.numel() > 0, "No valid depth pixels from far position" mean_close = valid_close.mean().item() mean_far = valid_far.mean().item() From 2c0fbc14065d0e892c59a07ba95f4eb996a5fe0d Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 10 Mar 2026 13:14:20 -0700 Subject: [PATCH 03/20] Add missing comment --- source/isaaclab/test/sensors/test_tiled_camera.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index b125b63b548e..8078136b3026 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -1660,6 +1660,10 @@ def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls device=camera.device, ) camera.set_world_poses_from_view(eyes_close, target) + + # Simulate for a few steps + # note: This is a workaround to ensure that the textures are loaded. + # Check "Known Issues" section in the documentation for more details. for _ in range(5): sim.step() camera.update(dt) From 332dd1f4788774a036dc1b4c862aeb65e8450f43 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 10 Mar 2026 19:39:06 -0700 Subject: [PATCH 04/20] Move update usd to the backend --- .../isaaclab/sensors/camera/tiled_camera.py | 4 +++- .../renderers/isaac_rtx_renderer.py | 18 +++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py index c829f302613f..572e14598d5a 100644 --- a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py @@ -172,7 +172,9 @@ def _initialize_impl(self): # Create a view for the sensor with Fabric enabled for fast pose queries, otherwise position will be stale. self._view = XformPrimView( - self.cfg.prim_path, device=self._device, stage=self.stage, sync_usd_on_fabric_write=True + self.cfg.prim_path, + device=self._device, + stage=self.stage, ) # Check that sizes are correct if self._view.count != self._num_envs: diff --git a/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py b/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py index 22b07f13def0..da67e96b2266 100644 --- a/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py +++ b/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py @@ -19,6 +19,7 @@ from isaaclab.app.settings_manager import get_settings_manager from isaaclab.renderers import BaseRenderer +from isaaclab.utils.math import convert_camera_frame_orientation_convention from isaaclab.utils.warp.kernels import reshape_tiled_image from .isaac_rtx_renderer_utils import ensure_isaac_rtx_render_update @@ -178,9 +179,20 @@ def update_camera( orientations: torch.Tensor, intrinsics: torch.Tensor, ): - """No-op for Replicator - uses USD camera prims directly. - See :meth:`~isaaclab.renderers.base_renderer.BaseRenderer.update_camera`.""" - pass + """Write camera poses to USD so Replicator picks up the latest transforms. + + Replicator reads camera transforms from USD prims, not Fabric. + TiledCamera disables ``sync_usd_on_fabric_write`` for performance, so + this method converts world-convention orientations back to the OpenGL + convention expected by USD camera prims and writes them directly. + + See :meth:`~isaaclab.renderers.base_renderer.BaseRenderer.update_camera`. + """ + sensor = render_data.sensor() if render_data.sensor else None + if sensor is None: + return + orientations_opengl = convert_camera_frame_orientation_convention(orientations, origin="world", target="opengl") + sensor._view._set_world_poses_usd(positions, orientations_opengl) def render(self, render_data: IsaacRtxRenderData): """Extract data from annotators and write to output buffers. From 421c53f3b17e9a81efc69d67ad75ba955c53bedc Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 10 Mar 2026 19:47:13 -0700 Subject: [PATCH 05/20] Remove comment --- source/isaaclab/isaaclab/sensors/camera/tiled_camera.py | 1 - 1 file changed, 1 deletion(-) diff --git a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py index 572e14598d5a..4e72cfa615bf 100644 --- a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py @@ -170,7 +170,6 @@ def _initialize_impl(self): # the view keeps references to the prims located in the stage self.renderer.prepare_stage(self.stage, self._num_envs) - # Create a view for the sensor with Fabric enabled for fast pose queries, otherwise position will be stale. self._view = XformPrimView( self.cfg.prim_path, device=self._device, From 450ce70df93b0a1dcbf5465c1b332d284dd02447 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Wed, 25 Mar 2026 08:41:40 -0700 Subject: [PATCH 06/20] Add CPU/GPU tests and re-create matrices --- .../isaaclab/sensors/camera/camera.py | 2 +- .../isaaclab/sim/views/xform_prim_view.py | 36 +++++++++++------- .../test/sensors/test_tiled_camera.py | 38 ++++++++++++++----- .../renderers/isaac_rtx_renderer.py | 15 +------- 4 files changed, 52 insertions(+), 39 deletions(-) diff --git a/source/isaaclab/isaaclab/sensors/camera/camera.py b/source/isaaclab/isaaclab/sensors/camera/camera.py index 70ccc6c14dd8..5470162dceaa 100644 --- a/source/isaaclab/isaaclab/sensors/camera/camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/camera.py @@ -450,7 +450,7 @@ def _initialize_impl(self): super()._initialize_impl() # Create a view for the sensor with Fabric enabled for fast pose queries, otherwise position will be stale. self._view = XformPrimView( - self.cfg.prim_path, device=self._device, stage=self.stage, sync_usd_on_fabric_write=True + self.cfg.prim_path, device=self._device, stage=self.stage, ) # Check that sizes are correct if self._view.count != self._num_envs: diff --git a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py index 211994a7226b..7eb4da13740e 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py +++ b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py @@ -742,8 +742,7 @@ def _set_world_poses_fabric( # Dummy array for scales (not modifying) scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - # Use cached fabricarray for world matrices - world_matrices = self._fabric_world_matrices + world_matrices = self._get_world_matrices_as_fabricarray() # Batch compose matrices with a single kernel launch # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device @@ -813,8 +812,7 @@ def _set_scales_fabric(self, scales: torch.Tensor, indices: Sequence[int] | None positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - # Use cached fabricarray for world matrices - world_matrices = self._fabric_world_matrices + world_matrices = self._get_world_matrices_as_fabricarray() # Batch compose matrices on GPU with a single kernel launch # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device @@ -873,9 +871,7 @@ def _get_world_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) scales_wp = self._fabric_dummy_buffer # Always use dummy for scales - # Use cached fabricarray for world matrices - # This eliminates the 0.06-0.30ms variability from creating fabricarray each call - world_matrices = self._fabric_world_matrices + world_matrices = self._get_world_matrices_as_fabricarray() # Launch GPU kernel to decompose matrices in parallel # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device @@ -946,8 +942,7 @@ def _get_scales_fabric(self, indices: Sequence[int] | None = None) -> torch.Tens positions_wp = self._fabric_dummy_buffer orientations_wp = self._fabric_dummy_buffer - # Use cached fabricarray for world matrices - world_matrices = self._fabric_world_matrices + world_matrices = self._get_world_matrices_as_fabricarray() # Launch GPU kernel to decompose matrices in parallel # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device @@ -1096,11 +1091,6 @@ def _initialize_fabric(self) -> None: # Dummy array for unused outputs (always empty) self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - # Cache fabricarray for world matrices to avoid recreation overhead - # Refs: https://docs.omniverse.nvidia.com/kit/docs/usdrt/latest/docs/usdrt_prim_selection.html - # https://docs.omniverse.nvidia.com/kit/docs/usdrt/latest/docs/scenegraph_use.html - self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") - # Cache Fabric stage to avoid expensive get_current_stage() calls self._fabric_stage = fabric_stage @@ -1112,6 +1102,24 @@ def _initialize_fabric(self) -> None: # made after the view was constructed. self._fabric_usd_sync_done = False + def _get_world_matrices_as_fabricarray(self) -> wp.fabricarray: + """Create a fresh fabricarray for world matrices. + + Recreating both the PrimSelection and fabricarray on each write ensures Fabric's + journaling system marks the attribute as dirty, so downstream consumers (renderers) + observe the update. + """ + import usdrt + + sel = self._fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), + (usdrt.Sdf.ValueTypeNames.Matrix4d, "omni:fabric:worldMatrix", usdrt.Usd.Access.ReadWrite), + ], + device=self._fabric_device, + ) + return wp.fabricarray(sel, "omni:fabric:worldMatrix") + def _sync_fabric_from_usd_once(self) -> None: """Sync Fabric world matrices from USD once, on the first read.""" # Ensure Fabric is initialized diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index 8078136b3026..506e6a8f0635 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -1620,7 +1620,10 @@ def test_output_equal_to_usd_camera_intrinsics(setup_camera, device): [(TiledCamera, TiledCameraCfg), (Camera, CameraCfg)], ids=["tiled", "non_tiled"], ) -@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +@pytest.mark.parametrize("device", [ + "cpu", + "cuda:0", + ]) @pytest.mark.isaacsim_ci def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls, cfg_cls): """Test that moving a camera is reflected in rendered depth. @@ -1628,7 +1631,7 @@ def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls Both camera types must produce different depth images when the camera is repositioned from close to far. """ - sim, _, dt = setup_camera + sim, __unused_camera_cfg, dt = setup_camera cam_cfg = cfg_cls( prim_path="/World/PoseTestCam", height=128, @@ -1660,12 +1663,7 @@ def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls device=camera.device, ) camera.set_world_poses_from_view(eyes_close, target) - - # Simulate for a few steps - # note: This is a workaround to ensure that the textures are loaded. - # Check "Known Issues" section in the documentation for more details. - for _ in range(5): - sim.step() + sim.step() camera.update(dt) depth_close = camera.data.output["distance_to_camera"].clone() @@ -1676,12 +1674,32 @@ def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls device=camera.device, ) camera.set_world_poses_from_view(eyes_far, target) - for _ in range(2): - sim.step() + sim.step() camera.update(dt) depth_far = camera.data.output["distance_to_camera"].clone() max_range = cam_cfg.spawn.clipping_range[1] + + def _save_depth_image(depth: torch.Tensor, filename: str): + """Save a depth tensor as a min-max normalized grayscale PNG for visual comparison.""" + img = depth[0].squeeze().cpu().float() + valid = img[img < max_range] + if valid.numel() > 0: + lo, hi = valid.min(), valid.max() + img = img.clamp(lo, hi) + img = (img - lo) / (hi - lo + 1e-6) + else: + img = img.clamp(0, max_range) / max_range + img = (img * 255).to(torch.uint8).numpy() + from PIL import Image + + Image.fromarray(img, mode="L").save(filename) + print(f"[DEBUG] Saved depth image: {filename}") + + cam_type = "tiled" if camera_cls is TiledCamera else "non_tiled" + _save_depth_image(depth_close, f"/tmp/depth_{cam_type}_{device}_close.png") + _save_depth_image(depth_far, f"/tmp/depth_{cam_type}_{device}_far.png") + valid_close = depth_close[depth_close < max_range] valid_far = depth_far[depth_far < max_range] diff --git a/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py b/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py index da67e96b2266..5cf06657e0b8 100644 --- a/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py +++ b/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py @@ -179,20 +179,7 @@ def update_camera( orientations: torch.Tensor, intrinsics: torch.Tensor, ): - """Write camera poses to USD so Replicator picks up the latest transforms. - - Replicator reads camera transforms from USD prims, not Fabric. - TiledCamera disables ``sync_usd_on_fabric_write`` for performance, so - this method converts world-convention orientations back to the OpenGL - convention expected by USD camera prims and writes them directly. - - See :meth:`~isaaclab.renderers.base_renderer.BaseRenderer.update_camera`. - """ - sensor = render_data.sensor() if render_data.sensor else None - if sensor is None: - return - orientations_opengl = convert_camera_frame_orientation_convention(orientations, origin="world", target="opengl") - sensor._view._set_world_poses_usd(positions, orientations_opengl) + pass def render(self, render_data: IsaacRtxRenderData): """Extract data from annotators and write to output buffers. From 6585962a2b6b5255800abed16e600726d825631e Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Wed, 25 Mar 2026 10:52:46 -0700 Subject: [PATCH 07/20] Enable Fabric CPU path --- .../isaaclab/isaaclab/sim/views/xform_prim_view.py | 13 +------------ source/isaaclab/test/sensors/test_tiled_camera.py | 2 +- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py index 7eb4da13740e..e6f2927e4227 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py +++ b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py @@ -140,22 +140,11 @@ def __init__( settings = SettingsManager.instance() self._use_fabric = bool(settings.get("/physics/fabricEnabled", False)) - # Check for unsupported Fabric + CPU combination - if self._use_fabric and self._device == "cpu": - logger.warning( - "Fabric mode with Warp fabric-array operations is not supported on CPU devices. " - "While Fabric itself can run on both CPU and GPU, our batch Warp kernels for " - "fabric-array operations require CUDA and are not reliable on the CPU backend. " - "To ensure stability, Fabric is being disabled and execution will fall back " - "to standard USD operations on the CPU. This may impact performance." - ) - self._use_fabric = False - # Check for unsupported Fabric + non-primary CUDA device combination. # USDRT SelectPrims and Warp fabric arrays only support cuda:0 internally. # When running on cuda:1 or higher, SelectPrims raises a C++ error regardless of # the device argument, because USDRT uses the active CUDA context (which is cuda:1). - if self._use_fabric and self._device not in ("cuda", "cuda:0"): + if self._use_fabric and self._device not in ("cpu", "cuda", "cuda:0"): logger.warning( f"Fabric mode is not supported on device '{self._device}'. " "USDRT SelectPrims and Warp fabric arrays only support cuda:0. " diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index 506e6a8f0635..a09f766bd628 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -1693,7 +1693,7 @@ def _save_depth_image(depth: torch.Tensor, filename: str): img = (img * 255).to(torch.uint8).numpy() from PIL import Image - Image.fromarray(img, mode="L").save(filename) + Image.fromarray(img).save(filename) print(f"[DEBUG] Saved depth image: {filename}") cam_type = "tiled" if camera_cls is TiledCamera else "non_tiled" From 4a17b9c1a2cf8c40cc7da150903b21badc59ef18 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Mon, 30 Mar 2026 17:02:29 -0700 Subject: [PATCH 08/20] Update unit tests to do not skip CPU path --- source/isaaclab/test/sim/test_views_xform_prim.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/isaaclab/test/sim/test_views_xform_prim.py b/source/isaaclab/test/sim/test_views_xform_prim.py index 3de7a0b357a2..2de28b78b5a0 100644 --- a/source/isaaclab/test/sim/test_views_xform_prim.py +++ b/source/isaaclab/test/sim/test_views_xform_prim.py @@ -62,8 +62,6 @@ def _skip_if_backend_unavailable(backend: str, device: str): """Skip tests when the requested backend is unavailable.""" if device.startswith("cuda") and not torch.cuda.is_available(): pytest.skip("CUDA not available") - if backend == "fabric" and device == "cpu": - pytest.skip("Warp fabricarray operations on CPU have known issues") def _prim_type_for_backend(backend: str) -> str: From ca44dda4db111ac3a840ef6021e84a99492faae0 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 31 Mar 2026 13:36:19 -0700 Subject: [PATCH 09/20] Refactoring to use fabric local matrix --- .../isaaclab/sim/views/xform_prim_view.py | 308 ++++++++++++++---- 1 file changed, 237 insertions(+), 71 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py index e6f2927e4227..8c26085f6c1b 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py +++ b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py @@ -80,6 +80,14 @@ class XformPrimView: time-sampled keyframes separately. """ + # -- Fabric attribute names -- + _WORLD_MATRIX_ATTR = "omni:fabric:worldMatrix" + _LOCAL_MATRIX_ATTR = "omni:fabric:localMatrix" + + _shared_fabric_hierarchy = None + _shared_fabric_stage_key: int | None = None + _world_xforms_dirty: bool = False + def __init__( self, prim_path: str, @@ -163,14 +171,15 @@ def __init__( # Fabric batch infrastructure (initialized lazily on first use) self._fabric_initialized = False self._fabric_usd_sync_done = False - self._fabric_selection = None - self._fabric_to_view: wp.array | None = None self._view_to_fabric: wp.array | None = None self._default_view_indices: wp.array | None = None self._fabric_hierarchy = None # Create a valid USD attribute name: namespace:name # Use "isaaclab" namespace to identify our custom attributes self._view_index_attr = f"isaaclab:view_index:{abs(hash(self))}" + + self._world_selection = None + self._local_selection = None """ Properties. @@ -379,31 +388,27 @@ def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: """Get local-space poses for prims in the view. - This method retrieves the position and orientation of each prim in local space (relative to - their parent prims). It reads directly from USD's ``xformOp:translate`` and ``xformOp:orient`` attributes. - - Note: - Even in Fabric mode, local pose operations use USD. This behavior is based on Isaac Sim's design - where Fabric is only used for world pose operations. + This method retrieves the position and orientation of each prim in local space + (relative to their parent prims). - Rationale: - - Local pose reads need correct parent-child hierarchy relationships - - USD maintains these relationships correctly and efficiently - - Fabric is optimized for world pose operations, not local hierarchies + When Fabric is enabled, reads ``omni:fabric:localMatrix`` and decomposes it + using GPU batch operations. Otherwise reads USD's ``xformOp:translate`` and + ``xformOp:orient`` via an :class:`UsdGeom.XformCache`. Note: Scale is ignored. The returned poses contain only translation and rotation. Args: - indices: Indices of prims to get poses for. Defaults to None, in which case poses are retrieved - for all prims in the view. + indices: Indices of prims to get poses for. Defaults to None, in which + case poses are retrieved for all prims in the view. Returns: A tuple of (translations, orientations) where: - - translations: Torch tensor of shape (M, 3) containing local-space translations (x, y, z), - where M is the number of prims queried. - - orientations: Torch tensor of shape (M, 4) containing local-space quaternions (w, x, y, z) + - translations: Torch tensor of shape (M, 3) containing local-space + translations (x, y, z), where M is the number of prims queried. + - orientations: Torch tensor of shape (M, 4) containing local-space + quaternions (x, y, z, w). """ if self._use_fabric: return self._get_local_poses_fabric(indices) @@ -711,6 +716,11 @@ def _set_world_poses_fabric( # Lazy initialization if not self._fabric_initialized: self._initialize_fabric() + # Flush pending local→world propagation so the explicit world write + # is not later overwritten by a deferred update_world_xforms() cascade. + if XformPrimView._world_xforms_dirty: + self._fabric_hierarchy.update_world_xforms() + XformPrimView._world_xforms_dirty = False # Resolve indices (treat slice(None) as None for consistency with USD path) indices_wp = self._resolve_indices_wp(indices) @@ -755,8 +765,6 @@ def _set_world_poses_fabric( # Synchronize to ensure kernel completes wp.synchronize() - # Update world transforms within Fabric hierarchy - self._fabric_hierarchy.update_world_xforms() # Fabric now has authoritative data; skip future USD syncs self._fabric_usd_sync_done = True # Mirror to USD for renderer-facing prims when enabled. @@ -771,17 +779,57 @@ def _set_local_poses_fabric( orientations: torch.Tensor | None = None, indices: Sequence[int] | None = None, ): - """Set local poses using USD (matches Isaac Sim's design). + """Set local poses using Fabric GPU batch operations. - Note: Even in Fabric mode, local pose operations use USD. - This is Isaac Sim's design: the ``usd=False`` parameter only affects world poses. - - Rationale: - - Local pose writes need correct parent-child hierarchy relationships - - USD maintains these relationships correctly and efficiently - - Fabric is optimized for world pose operations, not local hierarchies + Composes ``omni:fabric:localMatrix`` via a Warp kernel (GPU-batched, + handles partial updates), then re-registers each matrix through + :meth:`IFabricHierarchy.set_local_xform` so that Fabric's change + tracking picks it up. Finally calls ``update_world_xforms()`` which + propagates world matrices for every prim whose ``localMatrix`` changed. """ - self._set_local_poses_usd(translations, orientations, indices) + if not self._fabric_initialized: + self._initialize_fabric() + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + if translations is not None: + translations_wp = wp.from_torch(translations) + else: + translations_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + + if orientations is not None: + orientations_wp = wp.from_torch(orientations) + else: + orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) + + scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + + local_matrices = self._get_local_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, + dim=count, + inputs=[ + local_matrices, + translations_wp, + orientations_wp, + scales_wp, + False, + False, + False, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + + wp.synchronize() + + XformPrimView._world_xforms_dirty = True + self._fabric_usd_sync_done = True + if self._sync_usd_on_fabric_write: + self._set_local_poses_usd(translations, orientations, indices) def _set_scales_fabric(self, scales: torch.Tensor, indices: Sequence[int] | None = None): """Set scales using Fabric GPU batch operations.""" @@ -825,8 +873,6 @@ def _set_scales_fabric(self, scales: torch.Tensor, indices: Sequence[int] | None # Synchronize to ensure kernel completes before syncing wp.synchronize() - # Update world transforms to propagate changes - self._fabric_hierarchy.update_world_xforms() # Fabric now has authoritative data; skip future USD syncs self._fabric_usd_sync_done = True # Mirror to USD for renderer-facing prims when enabled. @@ -841,6 +887,10 @@ def _get_world_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple # Sync once from USD to ensure reads see the latest authored transforms if not self._fabric_usd_sync_done: self._sync_fabric_from_usd_once() + # Propagate local→world if any local poses were modified since last read + if XformPrimView._world_xforms_dirty: + self._fabric_hierarchy.update_world_xforms() + XformPrimView._world_xforms_dirty = False # Resolve indices (treat slice(None) as None for consistency with USD path) indices_wp = self._resolve_indices_wp(indices) @@ -891,18 +941,52 @@ def _get_world_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple return positions, orientations def _get_local_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get local poses using USD (matches Isaac Sim's design). + """Get local poses from Fabric using GPU batch operations. - Note: - Even in Fabric mode, local pose operations use USD's XformCache. - This is Isaac Sim's design: the ``usd=False`` parameter only affects world poses. - - Rationale: - - Local pose computation requires parent transforms which may not be in the view - - USD's XformCache provides efficient hierarchy-aware local transform queries - - Fabric is optimized for world pose operations, not local hierarchies + Reads ``omni:fabric:localMatrix`` and decomposes each 4x4 matrix into + translation and orientation, mirroring the world-pose Fabric path. """ - return self._get_local_poses_usd(indices) + if not self._fabric_initialized: + self._initialize_fabric() + if not self._fabric_usd_sync_done: + self._sync_fabric_from_usd_once() + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + use_cached_buffers = indices is None or indices == slice(None) + if use_cached_buffers: + translations_wp = self._fabric_local_translations_buffer + orientations_wp = self._fabric_local_orientations_buffer + scales_wp = self._fabric_dummy_buffer + else: + translations_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) + orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) + scales_wp = self._fabric_dummy_buffer + + local_matrices = self._get_local_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + dim=count, + inputs=[ + local_matrices, + translations_wp, + orientations_wp, + scales_wp, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + + if use_cached_buffers: + wp.synchronize() + return self._fabric_local_translations_torch, self._fabric_local_orientations_torch + else: + translations = wp.to_torch(translations_wp) + orientations = wp.to_torch(orientations_wp) + return translations, orientations def _get_scales_fabric(self, indices: Sequence[int] | None = None) -> torch.Tensor: """Get scales from Fabric using GPU batch operations.""" @@ -985,29 +1069,38 @@ def _initialize_fabric(self) -> None: # when Fabric Scene Delegate is enabled, but we ensure they exist for i in range(self.count): rt_prim = fabric_stage.GetPrimAtPath(self.prim_paths[i]) - rt_xformable = Rt.Xformable(rt_prim) - - # Create Fabric hierarchy world matrix attribute if it doesn't exist - has_attr = ( - rt_xformable.HasFabricHierarchyWorldMatrixAttr() - if hasattr(rt_xformable, "HasFabricHierarchyWorldMatrixAttr") - else False - ) - if not has_attr: - rt_xformable.CreateFabricHierarchyWorldMatrixAttr() - - # Best-effort USD->Fabric sync; authoritative initialization happens on first read. - rt_xformable.SetWorldXformFromUsd() # Create view index attribute for batch operations rt_prim.CreateAttribute(self._view_index_attr, usdrt.Sdf.ValueTypeNames.UInt, custom=True) rt_prim.GetAttribute(self._view_index_attr).Set(i) - # After syncing all prims, update the Fabric hierarchy to ensure world matrices are computed - self._fabric_hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( - fabric_stage.GetFabricId(), fabric_stage.GetStageIdAsStageId() - ) - self._fabric_hierarchy.update_world_xforms() + # Share a single hierarchy handle across all instances; rebuild when stage changes. + if XformPrimView._shared_fabric_hierarchy is None or XformPrimView._shared_fabric_stage_key != stage_id: + # Populate Fabric from USD once to establish parent-child connectivity. + # Without this, IFabricHierarchy cannot traverse the hierarchy and + # update_world_xforms() will not propagate parent transforms to children. + pop = usdrt.population.IUtils() + pop.set_enable_usd_notice_handling( + fabric_stage.GetStageIdAsStageId(), fabric_stage.GetFabricId(), True + ) + pop.populate_from_usd( + fabric_stage.GetStageReaderWriterId(), + fabric_stage.GetStageIdAsStageId(), + usdrt.Sdf.Path("/"), + 0, + ) + pop.apply_pending_usd_updates( + fabric_stage.GetStageIdAsStageId(), fabric_stage.GetStageReaderWriterId(), 0 + ) + + XformPrimView._shared_fabric_hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( + fabric_stage.GetFabricId(), fabric_stage.GetStageIdAsStageId() + ) + XformPrimView._shared_fabric_hierarchy.update_world_xforms() + XformPrimView._shared_fabric_hierarchy.track_local_xform_changes(True) + XformPrimView._shared_fabric_hierarchy.track_world_xform_changes(True) + XformPrimView._shared_fabric_stage_key = stage_id + self._fabric_hierarchy = XformPrimView._shared_fabric_hierarchy # Step 2: Create index arrays for batch operations self._default_view_indices = wp.zeros((self.count,), dtype=wp.uint32).to(self._device) @@ -1044,10 +1137,9 @@ def _initialize_fabric(self) -> None: ) fabric_device = "cuda:0" - self._fabric_selection = fabric_stage.SelectPrims( + index_selection = fabric_stage.SelectPrims( require_attrs=[ (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, "omni:fabric:worldMatrix", usdrt.Usd.Access.ReadWrite), ], device=fabric_device, ) @@ -1056,12 +1148,12 @@ def _initialize_fabric(self) -> None: # Note: fabric_to_view is tied to fabric_device (cuda:0) because it's created from SelectPrims. # view_to_fabric must also be on fabric_device since it's always used with fabricarrays in kernels. self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32).to(fabric_device) - self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr) + fabric_to_view = wp.fabricarray(index_selection, self._view_index_attr) wp.launch( kernel=fabric_utils.set_view_to_fabric_array, - dim=self._fabric_to_view.shape[0], - inputs=[self._fabric_to_view, self._view_to_fabric], + dim=fabric_to_view.shape[0], + inputs=[fabric_to_view, self._view_to_fabric], device=fabric_device, ) # Synchronize to ensure mapping is ready before any operations @@ -1077,9 +1169,24 @@ def _initialize_fabric(self) -> None: self._fabric_orientations_buffer = wp.from_torch(self._fabric_orientations_torch, dtype=wp.float32) self._fabric_scales_buffer = wp.from_torch(self._fabric_scales_torch, dtype=wp.float32) + # Pre-allocate reusable output buffers for local pose reads + self._fabric_local_translations_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) + self._fabric_local_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) + + self._fabric_local_translations_buffer = wp.from_torch( + self._fabric_local_translations_torch, dtype=wp.float32 + ) + self._fabric_local_orientations_buffer = wp.from_torch( + self._fabric_local_orientations_torch, dtype=wp.float32 + ) + # Dummy array for unused outputs (always empty) self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + # Cached selection / fabricarray for local matrices (lazy-initialized) + self._local_selection = None + self._fabric_local_array = None + # Cache Fabric stage to avoid expensive get_current_stage() calls self._fabric_stage = fabric_stage @@ -1100,14 +1207,45 @@ def _get_world_matrices_as_fabricarray(self) -> wp.fabricarray: """ import usdrt - sel = self._fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, "omni:fabric:worldMatrix", usdrt.Usd.Access.ReadWrite), - ], - device=self._fabric_device, - ) - return wp.fabricarray(sel, "omni:fabric:worldMatrix") + # Full rebuild of select prim, there is a bug in PrepareForReuse that + # causes CPU buffers to take precedence over GPU buffers. + if True: + self._world_selection = self._fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), + (usdrt.Sdf.ValueTypeNames.Matrix4d, self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), + ], + device=self._fabric_device, + ) + self._fabric_array = wp.fabricarray(self._world_selection, self._WORLD_MATRIX_ATTR) + else: + self._world_selection.PrepareForReuse() + + return self._fabric_array + + def _get_local_matrices_as_fabricarray(self) -> wp.fabricarray: + """Create a fresh fabricarray for local matrices. + + Mirrors :meth:`_get_world_matrices_as_fabricarray` but targets + ``omni:fabric:localMatrix``. + """ + import usdrt + + # Full rebuild of select prim, there is a bug in PrepareForReuse that + # causes CPU buffers to take precedence over GPU buffers. + if True: + self._local_selection = self._fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), + (usdrt.Sdf.ValueTypeNames.Matrix4d, self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), + ], + device=self._fabric_device, + ) + self._fabric_local_array = wp.fabricarray(self._local_selection, self._LOCAL_MATRIX_ATTR) + else: + self._local_selection.PrepareForReuse() + + return self._fabric_local_array def _sync_fabric_from_usd_once(self) -> None: """Sync Fabric world matrices from USD once, on the first read.""" @@ -1125,6 +1263,34 @@ def _sync_fabric_from_usd_once(self) -> None: self._set_scales_fabric(scales_usd) self._sync_usd_on_fabric_write = prev_sync + # Also sync local matrices so _get_local_poses_fabric() returns + # correct values. Write directly to localMatrix without calling + # update_world_xforms() — the world matrices above are already correct. + local_trans_usd, local_orient_usd = self._get_local_poses_usd() + local_trans_wp = wp.from_torch(local_trans_usd) + local_orient_wp = wp.from_torch(local_orient_usd) + scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + indices_wp = self._resolve_indices_wp(None) + local_matrices = self._get_local_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, + dim=self.count, + inputs=[ + local_matrices, + local_trans_wp, + local_orient_wp, + scales_wp, + False, + False, + False, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + wp.synchronize() + self._fabric_usd_sync_done = True def _resolve_indices_wp(self, indices: Sequence[int] | None) -> wp.array: From 167db07b57354d51a38d025cf8370c6c69103c9e Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 31 Mar 2026 16:15:52 -0700 Subject: [PATCH 10/20] Break down XFormPrimView into Usd and Fabric backends --- .../isaaclab/isaaclab/sim/views/__init__.pyi | 6 + .../isaaclab/sim/views/xform_backend.py | 110 ++ .../sim/views/xform_fabric_backend.py | 510 +++++++++ .../isaaclab/sim/views/xform_prim_view.py | 965 +----------------- .../isaaclab/sim/views/xform_usd_backend.py | 220 ++++ 5 files changed, 890 insertions(+), 921 deletions(-) create mode 100644 source/isaaclab/isaaclab/sim/views/xform_backend.py create mode 100644 source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py create mode 100644 source/isaaclab/isaaclab/sim/views/xform_usd_backend.py diff --git a/source/isaaclab/isaaclab/sim/views/__init__.pyi b/source/isaaclab/isaaclab/sim/views/__init__.pyi index a666958e4387..76864b225a6e 100644 --- a/source/isaaclab/isaaclab/sim/views/__init__.pyi +++ b/source/isaaclab/isaaclab/sim/views/__init__.pyi @@ -4,7 +4,13 @@ # SPDX-License-Identifier: BSD-3-Clause __all__ = [ + "FabricBackend", + "UsdBackend", + "XformBackend", "XformPrimView", ] +from .xform_backend import XformBackend +from .xform_fabric_backend import FabricBackend from .xform_prim_view import XformPrimView +from .xform_usd_backend import UsdBackend diff --git a/source/isaaclab/isaaclab/sim/views/xform_backend.py b/source/isaaclab/isaaclab/sim/views/xform_backend.py new file mode 100644 index 000000000000..e173d3764647 --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_backend.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Protocol + +import torch + + +class XformBackend(Protocol): + """Protocol defining the interface for :class:`XformPrimView` transform backends. + + Implementations provide read/write access to prim transforms through either + the USD or Fabric data path. :class:`XformPrimView` delegates all transform + operations to a *primary* backend and optionally replicates writes to one or + more *sync* backends. + """ + + def initialize(self) -> None: + """Perform any deferred initialisation required by the backend.""" + ... + + def set_world_poses( + self, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set world-space poses for the managed prims. + + Args: + positions: World-space positions, shape ``(M, 3)`` [m]. + orientations: World-space quaternions ``(x, y, z, w)``, shape ``(M, 4)``. + indices: Subset of prim indices to update. ``None`` means all. + """ + ... + + def get_world_poses( + self, + indices: Sequence[int] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Return world-space ``(positions, orientations)`` for the managed prims. + + Args: + indices: Subset of prim indices to query. ``None`` means all. + + Returns: + ``(positions, orientations)`` with shapes ``(M, 3)`` and ``(M, 4)``. + """ + ... + + def set_local_poses( + self, + translations: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set local-space poses (relative to each prim's parent). + + Args: + translations: Local-space translations, shape ``(M, 3)`` [m]. + orientations: Local-space quaternions ``(x, y, z, w)``, shape ``(M, 4)``. + indices: Subset of prim indices to update. ``None`` means all. + """ + ... + + def get_local_poses( + self, + indices: Sequence[int] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Return local-space ``(translations, orientations)`` for the managed prims. + + Args: + indices: Subset of prim indices to query. ``None`` means all. + + Returns: + ``(translations, orientations)`` with shapes ``(M, 3)`` and ``(M, 4)``. + """ + ... + + def set_scales( + self, + scales: torch.Tensor, + indices: Sequence[int] | None = None, + ) -> None: + """Set scales for the managed prims. + + Args: + scales: Scales, shape ``(M, 3)``. + indices: Subset of prim indices to update. ``None`` means all. + """ + ... + + def get_scales( + self, + indices: Sequence[int] | None = None, + ) -> torch.Tensor: + """Return scales for the managed prims. + + Args: + indices: Subset of prim indices to query. ``None`` means all. + + Returns: + Scales tensor of shape ``(M, 3)``. + """ + ... diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py new file mode 100644 index 000000000000..5372de458627 --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -0,0 +1,510 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import logging +from collections.abc import Sequence + +import torch +import warp as wp + +from pxr import Usd + +import isaaclab.sim as sim_utils +from isaaclab.utils.warp import fabric as fabric_utils + +logger = logging.getLogger(__name__) + + +class FabricBackend: + """Fabric-based transform backend for :class:`XformPrimView`. + + Uses NVIDIA's Fabric API with Warp GPU kernels for high-performance batch + transform operations. + """ + + _WORLD_MATRIX_ATTR = "omni:fabric:worldMatrix" + _LOCAL_MATRIX_ATTR = "omni:fabric:localMatrix" + + _hierarchy_cache: dict[int, object] = {} + _dirty_stages: set[int] = set() + + def __init__( + self, + prims: list[Usd.Prim], + device: str, + ): + self._prims = prims + self._device = device + self._view_index_attr = f"isaaclab:view_index:{abs(id(self))}" + + # Lazy-initialized state (None until initialize() runs) + self._view_to_fabric: wp.array | None = None + self._default_view_indices: wp.array | None = None + self._fabric_hierarchy = None + self._world_selection = None + self._local_selection = None + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def count(self) -> int: + """Number of prims managed by this backend.""" + return len(self._prims) + + @property + def prim_paths(self) -> list[str]: + """Prim path strings (lazily cached).""" + if not hasattr(self, "_prim_paths_cache"): + self._prim_paths_cache = [p.GetPath().pathString for p in self._prims] + return self._prim_paths_cache + + # ------------------------------------------------------------------ + # Initialization + # ------------------------------------------------------------------ + + def initialize(self) -> None: + """Set up Fabric batch infrastructure for GPU-accelerated pose queries. + + Idempotent — subsequent calls after the first are no-ops. + + Ensures all prims have the required Fabric hierarchy attributes + (``omni:fabric:localMatrix`` and ``omni:fabric:worldMatrix``) and + creates the index mapping, selections, and pre-allocated buffers + needed for Warp kernel launches. + """ + if self._fabric_hierarchy is not None: + return + import usdrt + from usdrt import Rt # noqa: F401 — imported for side-effects + + stage_id = sim_utils.get_current_stage_id() + fabric_stage = usdrt.Usd.Stage.Attach(stage_id) + + # Ensure every prim carries the view-index attribute + for i in range(self.count): + rt_prim = fabric_stage.GetPrimAtPath(self.prim_paths[i]) + rt_prim.CreateAttribute(self._view_index_attr, usdrt.Sdf.ValueTypeNames.UInt, custom=True) + rt_prim.GetAttribute(self._view_index_attr).Set(i) + + # Reuse (or create) a hierarchy handle for this stage. + if stage_id not in FabricBackend._hierarchy_cache: + pop = usdrt.population.IUtils() + pop.set_enable_usd_notice_handling( + fabric_stage.GetStageIdAsStageId(), + fabric_stage.GetFabricId(), + True, + ) + pop.populate_from_usd( + fabric_stage.GetStageReaderWriterId(), + fabric_stage.GetStageIdAsStageId(), + usdrt.Sdf.Path("/"), + 0, + ) + pop.apply_pending_usd_updates( + fabric_stage.GetStageIdAsStageId(), + fabric_stage.GetStageReaderWriterId(), + 0, + ) + + hierarchy = ( + usdrt.hierarchy.IFabricHierarchy() + .get_fabric_hierarchy( + fabric_stage.GetFabricId(), + fabric_stage.GetStageIdAsStageId(), + ) + ) + hierarchy.update_world_xforms() + hierarchy.track_local_xform_changes(True) + hierarchy.track_world_xform_changes(True) + FabricBackend._hierarchy_cache[stage_id] = hierarchy + + self._fabric_hierarchy = FabricBackend._hierarchy_cache[stage_id] + self._stage_id = stage_id + + # Default view-index array (0 … count-1) + self._default_view_indices = wp.zeros((self.count,), dtype=wp.uint32).to(self._device) + wp.launch( + kernel=fabric_utils.arange_k, + dim=self.count, + inputs=[self._default_view_indices], + device=self._device, + ) + wp.synchronize() + + # Resolve the Fabric device string (SelectPrims only supports cuda:0) + fabric_device = self._device + if self._device == "cuda": + logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.") + fabric_device = "cuda:0" + elif self._device.startswith("cuda:"): + if self._device != "cuda:0": + logger.debug( + f"SelectPrims only supports cuda:0. Using cuda:0 for SelectPrims " + f"even though simulation device is {self._device}." + ) + fabric_device = "cuda:0" + + # Build the bidirectional view ↔ fabric index mapping + index_selection = fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), + ], + device=fabric_device, + ) + + self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32).to(fabric_device) + fabric_to_view = wp.fabricarray(index_selection, self._view_index_attr) + + wp.launch( + kernel=fabric_utils.set_view_to_fabric_array, + dim=fabric_to_view.shape[0], + inputs=[fabric_to_view, self._view_to_fabric], + device=fabric_device, + ) + wp.synchronize() + + # Pre-allocate reusable output buffers (world poses) + self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) + self._fabric_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) + self._fabric_scales_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) + + self._fabric_positions_buffer = wp.from_torch(self._fabric_positions_torch, dtype=wp.float32) + self._fabric_orientations_buffer = wp.from_torch(self._fabric_orientations_torch, dtype=wp.float32) + self._fabric_scales_buffer = wp.from_torch(self._fabric_scales_torch, dtype=wp.float32) + + # Pre-allocate reusable output buffers (local poses) + self._fabric_local_translations_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) + self._fabric_local_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) + + self._fabric_local_translations_buffer = wp.from_torch( + self._fabric_local_translations_torch, dtype=wp.float32 + ) + self._fabric_local_orientations_buffer = wp.from_torch( + self._fabric_local_orientations_torch, dtype=wp.float32 + ) + + # Dummy buffer for unused kernel outputs (always empty) + self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + + self._local_selection = None + self._fabric_local_array = None + + self._fabric_stage = fabric_stage + self._fabric_device = fabric_device + + # ------------------------------------------------------------------ + # Setters + # ------------------------------------------------------------------ + + def set_world_poses( + self, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Write world poses to Fabric ``omni:fabric:worldMatrix`` via a Warp kernel.""" + self.initialize() + + if self._stage_id in FabricBackend._dirty_stages: + self._fabric_hierarchy.update_world_xforms() + FabricBackend._dirty_stages.discard(self._stage_id) + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + if positions is not None: + positions_wp = wp.from_torch(positions) + else: + positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + + if orientations is not None: + orientations_wp = wp.from_torch(orientations) + else: + orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) + + scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + world_matrices = self._get_world_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, + dim=count, + inputs=[ + world_matrices, + positions_wp, + orientations_wp, + scales_wp, + False, + False, + False, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + wp.synchronize() + + def set_local_poses( + self, + translations: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Write local poses to Fabric ``omni:fabric:localMatrix`` via a Warp kernel. + + After composing the local matrix the method re-registers it through + :pyobj:`IFabricHierarchy` and marks world transforms as dirty so that a + subsequent read will propagate the change. + """ + self.initialize() + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + if translations is not None: + translations_wp = wp.from_torch(translations) + else: + translations_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + + if orientations is not None: + orientations_wp = wp.from_torch(orientations) + else: + orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) + + scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + local_matrices = self._get_local_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, + dim=count, + inputs=[ + local_matrices, + translations_wp, + orientations_wp, + scales_wp, + False, + False, + False, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + wp.synchronize() + + FabricBackend._dirty_stages.add(self._stage_id) + + def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) -> None: + """Write scales into the Fabric world matrix via a Warp kernel.""" + self.initialize() + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + scales_wp = wp.from_torch(scales) + positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) + world_matrices = self._get_world_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, + dim=count, + inputs=[ + world_matrices, + positions_wp, + orientations_wp, + scales_wp, + False, + False, + False, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + wp.synchronize() + + # ------------------------------------------------------------------ + # Getters + # ------------------------------------------------------------------ + + def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Read world poses from Fabric and decompose via a Warp kernel.""" + self.initialize() + if self._stage_id in FabricBackend._dirty_stages: + self._fabric_hierarchy.update_world_xforms() + FabricBackend._dirty_stages.discard(self._stage_id) + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + use_cached_buffers = indices is None or indices == slice(None) + if use_cached_buffers: + positions_wp = self._fabric_positions_buffer + orientations_wp = self._fabric_orientations_buffer + scales_wp = self._fabric_dummy_buffer + else: + positions_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) + orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) + scales_wp = self._fabric_dummy_buffer + + world_matrices = self._get_world_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + dim=count, + inputs=[ + world_matrices, + positions_wp, + orientations_wp, + scales_wp, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + + if use_cached_buffers: + wp.synchronize() + return self._fabric_positions_torch, self._fabric_orientations_torch + else: + positions = wp.to_torch(positions_wp) + orientations = wp.to_torch(orientations_wp) + return positions, orientations + + def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Read local poses from Fabric and decompose via a Warp kernel.""" + self.initialize() + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + use_cached_buffers = indices is None or indices == slice(None) + if use_cached_buffers: + translations_wp = self._fabric_local_translations_buffer + orientations_wp = self._fabric_local_orientations_buffer + scales_wp = self._fabric_dummy_buffer + else: + translations_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) + orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) + scales_wp = self._fabric_dummy_buffer + + local_matrices = self._get_local_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + dim=count, + inputs=[ + local_matrices, + translations_wp, + orientations_wp, + scales_wp, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + + if use_cached_buffers: + wp.synchronize() + return self._fabric_local_translations_torch, self._fabric_local_orientations_torch + else: + translations = wp.to_torch(translations_wp) + orientations = wp.to_torch(orientations_wp) + return translations, orientations + + def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: + """Read scales from Fabric world matrices and extract via a Warp kernel.""" + self.initialize() + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + use_cached_buffers = indices is None or indices == slice(None) + if use_cached_buffers: + scales_wp = self._fabric_scales_buffer + else: + scales_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) + + positions_wp = self._fabric_dummy_buffer + orientations_wp = self._fabric_dummy_buffer + world_matrices = self._get_world_matrices_as_fabricarray() + + wp.launch( + kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + dim=count, + inputs=[ + world_matrices, + positions_wp, + orientations_wp, + scales_wp, + indices_wp, + self._view_to_fabric, + ], + device=self._fabric_device, + ) + + if use_cached_buffers: + wp.synchronize() + return self._fabric_scales_torch + else: + return wp.to_torch(scales_wp) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _get_world_matrices_as_fabricarray(self) -> wp.fabricarray: + """Create a fresh :class:`wp.fabricarray` backed by ``omni:fabric:worldMatrix``. + + Rebuilding both the :class:`PrimSelection` and fabricarray each call + ensures Fabric's journaling marks the attribute dirty for downstream + consumers (renderers, etc.). + """ + import usdrt + + if True: + self._world_selection = self._fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), + (usdrt.Sdf.ValueTypeNames.Matrix4d, self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), + ], + device=self._fabric_device, + ) + self._fabric_array = wp.fabricarray(self._world_selection, self._WORLD_MATRIX_ATTR) + else: + self._world_selection.PrepareForReuse() + + return self._fabric_array + + def _get_local_matrices_as_fabricarray(self) -> wp.fabricarray: + """Create a fresh :class:`wp.fabricarray` backed by ``omni:fabric:localMatrix``.""" + import usdrt + + if True: + self._local_selection = self._fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), + (usdrt.Sdf.ValueTypeNames.Matrix4d, self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), + ], + device=self._fabric_device, + ) + self._fabric_local_array = wp.fabricarray(self._local_selection, self._LOCAL_MATRIX_ATTR) + else: + self._local_selection.PrepareForReuse() + + return self._fabric_local_array + + def _resolve_indices_wp(self, indices: Sequence[int] | None) -> wp.array: + """Convert view indices to a Warp :class:`wp.array`.""" + if indices is None or indices == slice(None): + if self._default_view_indices is None: + raise RuntimeError("Fabric indices are not initialized.") + return self._default_view_indices + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + return wp.array(indices_list, dtype=wp.uint32).to(self._device) diff --git a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py index 8c26085f6c1b..5f6276166d8e 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py +++ b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py @@ -8,15 +8,16 @@ import logging from collections.abc import Sequence -import numpy as np import torch -import warp as wp -from pxr import Gf, Sdf, Usd, UsdGeom, Vt +from pxr import Sdf, Usd, UsdGeom import isaaclab.sim as sim_utils from isaaclab.app.settings_manager import SettingsManager -from isaaclab.utils.warp import fabric as fabric_utils + +from .xform_backend import XformBackend +from .xform_fabric_backend import FabricBackend +from .xform_usd_backend import UsdBackend logger = logging.getLogger(__name__) @@ -80,14 +81,6 @@ class XformPrimView: time-sampled keyframes separately. """ - # -- Fabric attribute names -- - _WORLD_MATRIX_ATTR = "omni:fabric:worldMatrix" - _LOCAL_MATRIX_ATTR = "omni:fabric:localMatrix" - - _shared_fabric_hierarchy = None - _shared_fabric_stage_key: int | None = None - _world_xforms_dirty: bool = False - def __init__( self, prim_path: str, @@ -125,7 +118,6 @@ def __init__( ValueError: If any matched prim is not Xformable or doesn't have standardized transform operations (translate, orient, scale in that order). """ - # Store configuration self._prim_path = prim_path self._device = device @@ -133,7 +125,6 @@ def __init__( stage = sim_utils.get_current_stage() if stage is None else stage self._prims: list[Usd.Prim] = sim_utils.find_matching_prims(prim_path, stage=stage) - # Validate all prims have standard xform operations if validate_xform_ops: for prim in self._prims: sim_utils.standardize_xform_ops(prim) @@ -144,46 +135,38 @@ def __init__( " Use sim_utils.standardize_xform_ops() to prepare the prim." ) - # Determine if Fabric is supported on the device + # Determine whether Fabric is available settings = SettingsManager.instance() - self._use_fabric = bool(settings.get("/physics/fabricEnabled", False)) + use_fabric = bool(settings.get("/physics/fabricEnabled", False)) - # Check for unsupported Fabric + non-primary CUDA device combination. - # USDRT SelectPrims and Warp fabric arrays only support cuda:0 internally. - # When running on cuda:1 or higher, SelectPrims raises a C++ error regardless of - # the device argument, because USDRT uses the active CUDA context (which is cuda:1). - if self._use_fabric and self._device not in ("cpu", "cuda", "cuda:0"): + if use_fabric and self._device not in ("cpu", "cuda", "cuda:0"): logger.warning( f"Fabric mode is not supported on device '{self._device}'. " "USDRT SelectPrims and Warp fabric arrays only support cuda:0. " "Falling back to standard USD operations. This may impact performance." ) - self._use_fabric = False + use_fabric = False - # Create indices buffer - # Since we iterate over the indices, we need to use range instead of torch tensor + # Index list used by visibility (USD-only) self._ALL_INDICES = list(range(len(self._prims))) - # Some prims (e.g., Cameras) require USD-authored transforms for rendering. - # When enabled, mirror Fabric pose writes to USD for those prims. - self._sync_usd_on_fabric_write = sync_usd_on_fabric_write - - # Fabric batch infrastructure (initialized lazily on first use) - self._fabric_initialized = False - self._fabric_usd_sync_done = False - self._view_to_fabric: wp.array | None = None - self._default_view_indices: wp.array | None = None - self._fabric_hierarchy = None - # Create a valid USD attribute name: namespace:name - # Use "isaaclab" namespace to identify our custom attributes - self._view_index_attr = f"isaaclab:view_index:{abs(hash(self))}" - - self._world_selection = None - self._local_selection = None + # ---- Create backends ------------------------------------------------ + if use_fabric: + self._backend: XformBackend = FabricBackend( + self._prims, self._device + ) + self._sync_backends: list[XformBackend] = ( + [UsdBackend(self._prims, self._device)] + if sync_usd_on_fabric_write + else [] + ) + else: + self._backend = UsdBackend(self._prims, self._device) + self._sync_backends = [] - """ - Properties. - """ + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ @property def count(self) -> int: @@ -212,16 +195,13 @@ def prim_paths(self) -> list[str]: to the USD prim objects without the conversion overhead. This property is mainly useful for logging, debugging, or when string paths are explicitly required. """ - # we cache it the first time it is accessed. - # we don't compute it in constructor because it is expensive and we don't need it most of the time. - # users should usually deal with prims directly as they typically need to access the prims directly. if not hasattr(self, "_prim_paths"): self._prim_paths = [prim.GetPath().pathString for prim in self._prims] return self._prim_paths - """ - Operations - Setters. - """ + # ------------------------------------------------------------------ + # Operations – Setters + # ------------------------------------------------------------------ def set_world_poses( self, @@ -251,10 +231,9 @@ def set_world_poses( ValueError: If positions shape is not (M, 3) or orientations shape is not (M, 4). ValueError: If the number of poses doesn't match the number of indices provided. """ - if self._use_fabric: - self._set_world_poses_fabric(positions, orientations, indices) - else: - self._set_world_poses_usd(positions, orientations, indices) + self._backend.set_world_poses(positions, orientations, indices) + for sync in self._sync_backends: + sync.set_world_poses(positions, orientations, indices) def set_local_poses( self, @@ -291,10 +270,9 @@ def set_local_poses( ValueError: If translations shape is not (M, 3) or orientations shape is not (M, 4). ValueError: If the number of poses doesn't match the number of indices provided. """ - if self._use_fabric: - self._set_local_poses_fabric(translations, orientations, indices) - else: - self._set_local_poses_usd(translations, orientations, indices) + self._backend.set_local_poses(translations, orientations, indices) + for sync in self._sync_backends: + sync.set_local_poses(translations, orientations, indices) def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None): """Set scales for prims in the view. @@ -313,10 +291,9 @@ def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) Raises: ValueError: If scales shape is not (M, 3). """ - if self._use_fabric: - self._set_scales_fabric(scales, indices) - else: - self._set_scales_usd(scales, indices) + self._backend.set_scales(scales, indices) + for sync in self._sync_backends: + sync.set_scales(scales, indices) def set_visibility(self, visibility: torch.Tensor, indices: Sequence[int] | None = None): """Set visibility for prims in the view. @@ -332,30 +309,25 @@ def set_visibility(self, visibility: torch.Tensor, indices: Sequence[int] | None Raises: ValueError: If visibility shape is not (M,). """ - # Resolve indices if indices is None or indices == slice(None): indices_list = self._ALL_INDICES else: indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - # Validate inputs if visibility.shape != (len(indices_list),): raise ValueError(f"Expected visibility shape ({len(indices_list)},), got {visibility.shape}.") - # Set visibility for each prim with Sdf.ChangeBlock(): for idx, prim_idx in enumerate(indices_list): - # Convert prim to imageable imageable = UsdGeom.Imageable(self._prims[prim_idx]) - # Set visibility if visibility[idx]: imageable.MakeVisible() else: imageable.MakeInvisible() - """ - Operations - Getters. - """ + # ------------------------------------------------------------------ + # Operations – Getters + # ------------------------------------------------------------------ def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: """Get world-space poses for prims in the view. @@ -380,10 +352,7 @@ def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T where M is the number of prims queried. - orientations: Torch tensor of shape (M, 4) containing world-space quaternions (w, x, y, z) """ - if self._use_fabric: - return self._get_world_poses_fabric(indices) - else: - return self._get_world_poses_usd(indices) + return self._backend.get_world_poses(indices) def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: """Get local-space poses for prims in the view. @@ -410,10 +379,7 @@ def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T - orientations: Torch tensor of shape (M, 4) containing local-space quaternions (x, y, z, w). """ - if self._use_fabric: - return self._get_local_poses_fabric(indices) - else: - return self._get_local_poses_usd(indices) + return self._backend.get_local_poses(indices) def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: """Get scales for prims in the view. @@ -431,10 +397,7 @@ def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: Returns: A tensor of shape (M, 3) containing the scales of each prim, where M is the number of prims queried. """ - if self._use_fabric: - return self._get_scales_fabric(indices) - else: - return self._get_scales_usd(indices) + return self._backend.get_scales(indices) def get_visibility(self, indices: Sequence[int] | None = None) -> torch.Tensor: """Get visibility for prims in the view. @@ -449,855 +412,15 @@ def get_visibility(self, indices: Sequence[int] | None = None) -> torch.Tensor: A tensor of shape (M,) containing the visibility of each prim, where M is the number of prims queried. The tensor is of type bool. """ - # Resolve indices if indices is None or indices == slice(None): indices_list = self._ALL_INDICES else: - # Convert to list if it is a tensor array indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - # Create buffers visibility = torch.zeros(len(indices_list), dtype=torch.bool, device=self._device) for idx, prim_idx in enumerate(indices_list): - # Get prim imageable = UsdGeom.Imageable(self._prims[prim_idx]) - # Get visibility visibility[idx] = imageable.ComputeVisibility() != UsdGeom.Tokens.invisible return visibility - - """ - Internal Functions - USD. - """ - - def _set_world_poses_usd( - self, - positions: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set world poses to USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - # Convert to list if it is a tensor array - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Validate inputs - if positions is not None: - if positions.shape != (len(indices_list), 3): - raise ValueError( - f"Expected positions shape ({len(indices_list)}, 3), got {positions.shape}. " - "Number of positions must match the number of prims in the view." - ) - positions_array = Vt.Vec3dArray.FromNumpy(positions.cpu().numpy()) - else: - positions_array = None - if orientations is not None: - if orientations.shape != (len(indices_list), 4): - raise ValueError( - f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}. " - "Number of orientations must match the number of prims in the view." - ) - # Vt expects quaternions in xyzw order - orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) - else: - orientations_array = None - - # Create xform cache instance - xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) - - # Set poses for each prim - # We use Sdf.ChangeBlock to minimize notification overhead. - with Sdf.ChangeBlock(): - for idx, prim_idx in enumerate(indices_list): - # Get prim - prim = self._prims[prim_idx] - # Get parent prim for local space conversion - parent_prim = prim.GetParent() - - # Determine what to set - world_pos = positions_array[idx] if positions_array is not None else None - world_quat = orientations_array[idx] if orientations_array is not None else None - - # Convert world pose to local if we have a valid parent - # Note: We don't use :func:`isaaclab.sim.utils.transforms.convert_world_pose_to_local` - # here since it isn't optimized for batch operations. - if parent_prim.IsValid() and parent_prim.GetPath() != Sdf.Path.absoluteRootPath: - # Get current world pose if we're only setting one component - if positions_array is None or orientations_array is None: - # get prim xform - prim_tf = xform_cache.GetLocalToWorldTransform(prim) - # sanitize quaternion - # this is needed, otherwise the quaternion might be non-normalized - prim_tf.Orthonormalize() - # populate desired world transform - if world_pos is not None: - prim_tf.SetTranslateOnly(world_pos) - if world_quat is not None: - prim_tf.SetRotateOnly(world_quat) - else: - # Both position and orientation are provided, create new transform - prim_tf = Gf.Matrix4d() - prim_tf.SetTranslateOnly(world_pos) - prim_tf.SetRotateOnly(world_quat) - - # Convert to local space - parent_world_tf = xform_cache.GetLocalToWorldTransform(parent_prim) - local_tf = prim_tf * parent_world_tf.GetInverse() - local_pos = local_tf.ExtractTranslation() - local_quat = local_tf.ExtractRotationQuat() - else: - # No parent or parent is root, world == local - local_pos = world_pos - local_quat = world_quat - - # Get or create the standard transform operations - if local_pos is not None: - prim.GetAttribute("xformOp:translate").Set(local_pos) - if local_quat is not None: - prim.GetAttribute("xformOp:orient").Set(local_quat) - - def _set_local_poses_usd( - self, - translations: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set local poses to USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Validate inputs - if translations is not None: - if translations.shape != (len(indices_list), 3): - raise ValueError(f"Expected translations shape ({len(indices_list)}, 3), got {translations.shape}.") - translations_array = Vt.Vec3dArray.FromNumpy(translations.cpu().numpy()) - else: - translations_array = None - if orientations is not None: - if orientations.shape != (len(indices_list), 4): - raise ValueError(f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}.") - orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) - else: - orientations_array = None - - # Set local poses - with Sdf.ChangeBlock(): - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - if translations_array is not None: - prim.GetAttribute("xformOp:translate").Set(translations_array[idx]) - if orientations_array is not None: - prim.GetAttribute("xformOp:orient").Set(orientations_array[idx]) - - def _set_scales_usd(self, scales: torch.Tensor, indices: Sequence[int] | None = None): - """Set scales to USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Validate inputs - if scales.shape != (len(indices_list), 3): - raise ValueError(f"Expected scales shape ({len(indices_list)}, 3), got {scales.shape}.") - - scales_array = Vt.Vec3dArray.FromNumpy(scales.cpu().numpy()) - # Set scales for each prim - with Sdf.ChangeBlock(): - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - prim.GetAttribute("xformOp:scale").Set(scales_array[idx]) - - def _get_world_poses_usd(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get world poses from USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - # Convert to list if it is a tensor array - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Create buffers - positions = Vt.Vec3dArray(len(indices_list)) - orientations = Vt.QuatdArray(len(indices_list)) - # Create xform cache instance - xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) - - # Note: We don't use :func:`isaaclab.sim.utils.transforms.resolve_prim_pose` - # here since it isn't optimized for batch operations. - for idx, prim_idx in enumerate(indices_list): - # Get prim - prim = self._prims[prim_idx] - # get prim xform - prim_tf = xform_cache.GetLocalToWorldTransform(prim) - # sanitize quaternion - # this is needed, otherwise the quaternion might be non-normalized - prim_tf.Orthonormalize() - # extract position and orientation - positions[idx] = prim_tf.ExtractTranslation() - orientations[idx] = prim_tf.ExtractRotationQuat() - - # move to torch tensors - positions = torch.tensor(np.array(positions), dtype=torch.float32, device=self._device) - orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) - return positions, orientations # type: ignore - - def _get_local_poses_usd(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get local poses from USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Create buffers - translations = Vt.Vec3dArray(len(indices_list)) - orientations = Vt.QuatdArray(len(indices_list)) - - # Create a fresh XformCache to avoid stale cached values - xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) - - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - prim_tf = xform_cache.GetLocalTransformation(prim)[0] - prim_tf.Orthonormalize() - translations[idx] = prim_tf.ExtractTranslation() - orientations[idx] = prim_tf.ExtractRotationQuat() - - translations = torch.tensor(np.array(translations), dtype=torch.float32, device=self._device) - orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) - return translations, orientations # type: ignore - - def _get_scales_usd(self, indices: Sequence[int] | None = None) -> torch.Tensor: - """Get scales from USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Create buffers - scales = Vt.Vec3dArray(len(indices_list)) - - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - scales[idx] = prim.GetAttribute("xformOp:scale").Get() - - # Convert to tensor - return torch.tensor(np.array(scales), dtype=torch.float32, device=self._device) - - """ - Internal Functions - Fabric. - """ - - def _set_world_poses_fabric( - self, - positions: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set world poses using Fabric GPU batch operations. - - Writes directly to Fabric's ``omni:fabric:worldMatrix`` attribute using Warp kernels. - Changes are propagated through Fabric's hierarchy system but remain GPU-resident. - - For workflows mixing Fabric world pose writes with USD local pose queries, note - that local poses read from USD's xformOp:* attributes, which may not immediately - reflect Fabric changes. For best performance and consistency, use Fabric methods - exclusively (get_world_poses/set_world_poses with Fabric enabled). - """ - # Lazy initialization - if not self._fabric_initialized: - self._initialize_fabric() - # Flush pending local→world propagation so the explicit world write - # is not later overwritten by a deferred update_world_xforms() cascade. - if XformPrimView._world_xforms_dirty: - self._fabric_hierarchy.update_world_xforms() - XformPrimView._world_xforms_dirty = False - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Convert torch to warp (if provided), use dummy arrays for None to avoid Warp kernel issues - if positions is not None: - positions_wp = wp.from_torch(positions) - else: - positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - if orientations is not None: - orientations_wp = wp.from_torch(orientations) - else: - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - - # Dummy array for scales (not modifying) - scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - world_matrices = self._get_world_matrices_as_fabricarray() - - # Batch compose matrices with a single kernel launch - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, - orientations_wp, - scales_wp, # dummy array instead of None - False, # broadcast_positions - False, # broadcast_orientations - False, # broadcast_scales - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Synchronize to ensure kernel completes - wp.synchronize() - - # Fabric now has authoritative data; skip future USD syncs - self._fabric_usd_sync_done = True - # Mirror to USD for renderer-facing prims when enabled. - if self._sync_usd_on_fabric_write: - self._set_world_poses_usd(positions, orientations, indices) - - # Fabric writes are GPU-resident; local pose operations still use USD. - - def _set_local_poses_fabric( - self, - translations: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set local poses using Fabric GPU batch operations. - - Composes ``omni:fabric:localMatrix`` via a Warp kernel (GPU-batched, - handles partial updates), then re-registers each matrix through - :meth:`IFabricHierarchy.set_local_xform` so that Fabric's change - tracking picks it up. Finally calls ``update_world_xforms()`` which - propagates world matrices for every prim whose ``localMatrix`` changed. - """ - if not self._fabric_initialized: - self._initialize_fabric() - - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - if translations is not None: - translations_wp = wp.from_torch(translations) - else: - translations_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - if orientations is not None: - orientations_wp = wp.from_torch(orientations) - else: - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - - scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - local_matrices = self._get_local_matrices_as_fabricarray() - - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - local_matrices, - translations_wp, - orientations_wp, - scales_wp, - False, - False, - False, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - wp.synchronize() - - XformPrimView._world_xforms_dirty = True - self._fabric_usd_sync_done = True - if self._sync_usd_on_fabric_write: - self._set_local_poses_usd(translations, orientations, indices) - - def _set_scales_fabric(self, scales: torch.Tensor, indices: Sequence[int] | None = None): - """Set scales using Fabric GPU batch operations.""" - # Lazy initialization - if not self._fabric_initialized: - self._initialize_fabric() - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Convert torch to warp - scales_wp = wp.from_torch(scales) - - # Dummy arrays for positions and orientations (not modifying) - positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - - world_matrices = self._get_world_matrices_as_fabricarray() - - # Batch compose matrices on GPU with a single kernel launch - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, # dummy array instead of None - orientations_wp, # dummy array instead of None - scales_wp, - False, # broadcast_positions - False, # broadcast_orientations - False, # broadcast_scales - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Synchronize to ensure kernel completes before syncing - wp.synchronize() - - # Fabric now has authoritative data; skip future USD syncs - self._fabric_usd_sync_done = True - # Mirror to USD for renderer-facing prims when enabled. - if self._sync_usd_on_fabric_write: - self._set_scales_usd(scales, indices) - - def _get_world_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get world poses from Fabric using GPU batch operations.""" - # Lazy initialization of Fabric infrastructure - if not self._fabric_initialized: - self._initialize_fabric() - # Sync once from USD to ensure reads see the latest authored transforms - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() - # Propagate local→world if any local poses were modified since last read - if XformPrimView._world_xforms_dirty: - self._fabric_hierarchy.update_world_xforms() - XformPrimView._world_xforms_dirty = False - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Use pre-allocated buffers for full reads, allocate only for partial reads - use_cached_buffers = indices is None or indices == slice(None) - if use_cached_buffers: - # Full read: Use cached buffers (zero allocation overhead!) - positions_wp = self._fabric_positions_buffer - orientations_wp = self._fabric_orientations_buffer - scales_wp = self._fabric_dummy_buffer - else: - # Partial read: Need to allocate buffers of appropriate size - positions_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) - orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) - scales_wp = self._fabric_dummy_buffer # Always use dummy for scales - - world_matrices = self._get_world_matrices_as_fabricarray() - - # Launch GPU kernel to decompose matrices in parallel - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, - orientations_wp, - scales_wp, # dummy array instead of None - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Return tensors: zero-copy for cached buffers, conversion for partial reads - if use_cached_buffers: - # Zero-copy! The Warp kernel wrote directly into the PyTorch tensors - # We just need to synchronize to ensure the kernel is done - wp.synchronize() - return self._fabric_positions_torch, self._fabric_orientations_torch - else: - # Partial read: Need to convert from Warp to torch - positions = wp.to_torch(positions_wp) - orientations = wp.to_torch(orientations_wp) - return positions, orientations - - def _get_local_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get local poses from Fabric using GPU batch operations. - - Reads ``omni:fabric:localMatrix`` and decomposes each 4x4 matrix into - translation and orientation, mirroring the world-pose Fabric path. - """ - if not self._fabric_initialized: - self._initialize_fabric() - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() - - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - use_cached_buffers = indices is None or indices == slice(None) - if use_cached_buffers: - translations_wp = self._fabric_local_translations_buffer - orientations_wp = self._fabric_local_orientations_buffer - scales_wp = self._fabric_dummy_buffer - else: - translations_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) - orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) - scales_wp = self._fabric_dummy_buffer - - local_matrices = self._get_local_matrices_as_fabricarray() - - wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, - inputs=[ - local_matrices, - translations_wp, - orientations_wp, - scales_wp, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - if use_cached_buffers: - wp.synchronize() - return self._fabric_local_translations_torch, self._fabric_local_orientations_torch - else: - translations = wp.to_torch(translations_wp) - orientations = wp.to_torch(orientations_wp) - return translations, orientations - - def _get_scales_fabric(self, indices: Sequence[int] | None = None) -> torch.Tensor: - """Get scales from Fabric using GPU batch operations.""" - # Lazy initialization - if not self._fabric_initialized: - self._initialize_fabric() - # Sync once from USD to ensure reads see the latest authored transforms - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Use pre-allocated buffers for full reads, allocate only for partial reads - use_cached_buffers = indices is None or indices == slice(None) - if use_cached_buffers: - # Full read: Use cached buffers (zero allocation overhead!) - scales_wp = self._fabric_scales_buffer - else: - # Partial read: Need to allocate buffer of appropriate size - scales_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) - - # Always use dummy buffers for positions and orientations (not needed for scales) - positions_wp = self._fabric_dummy_buffer - orientations_wp = self._fabric_dummy_buffer - - world_matrices = self._get_world_matrices_as_fabricarray() - - # Launch GPU kernel to decompose matrices in parallel - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, # dummy array instead of None - orientations_wp, # dummy array instead of None - scales_wp, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Return tensor: zero-copy for cached buffers, conversion for partial reads - if use_cached_buffers: - # Zero-copy! The Warp kernel wrote directly into the PyTorch tensor - wp.synchronize() - return self._fabric_scales_torch - else: - # Partial read: Need to convert from Warp to torch - return wp.to_torch(scales_wp) - - """ - Internal Functions - Initialization. - """ - - def _initialize_fabric(self) -> None: - """Initialize Fabric batch infrastructure for GPU-accelerated pose queries. - - This method ensures all prims have the required Fabric hierarchy attributes - (``omni:fabric:localMatrix`` and ``omni:fabric:worldMatrix``) and creates the necessary - infrastructure for batch GPU operations using Warp. - - Based on the Fabric Hierarchy documentation, when Fabric Scene Delegate is enabled, - all boundable prims should have these attributes. This method ensures they exist - and are properly synchronized with USD. - """ - import usdrt - from usdrt import Rt - - # Get USDRT (Fabric) stage - stage_id = sim_utils.get_current_stage_id() - fabric_stage = usdrt.Usd.Stage.Attach(stage_id) - - # Step 1: Ensure all prims have Fabric hierarchy attributes - # According to the documentation, these attributes are created automatically - # when Fabric Scene Delegate is enabled, but we ensure they exist - for i in range(self.count): - rt_prim = fabric_stage.GetPrimAtPath(self.prim_paths[i]) - - # Create view index attribute for batch operations - rt_prim.CreateAttribute(self._view_index_attr, usdrt.Sdf.ValueTypeNames.UInt, custom=True) - rt_prim.GetAttribute(self._view_index_attr).Set(i) - - # Share a single hierarchy handle across all instances; rebuild when stage changes. - if XformPrimView._shared_fabric_hierarchy is None or XformPrimView._shared_fabric_stage_key != stage_id: - # Populate Fabric from USD once to establish parent-child connectivity. - # Without this, IFabricHierarchy cannot traverse the hierarchy and - # update_world_xforms() will not propagate parent transforms to children. - pop = usdrt.population.IUtils() - pop.set_enable_usd_notice_handling( - fabric_stage.GetStageIdAsStageId(), fabric_stage.GetFabricId(), True - ) - pop.populate_from_usd( - fabric_stage.GetStageReaderWriterId(), - fabric_stage.GetStageIdAsStageId(), - usdrt.Sdf.Path("/"), - 0, - ) - pop.apply_pending_usd_updates( - fabric_stage.GetStageIdAsStageId(), fabric_stage.GetStageReaderWriterId(), 0 - ) - - XformPrimView._shared_fabric_hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( - fabric_stage.GetFabricId(), fabric_stage.GetStageIdAsStageId() - ) - XformPrimView._shared_fabric_hierarchy.update_world_xforms() - XformPrimView._shared_fabric_hierarchy.track_local_xform_changes(True) - XformPrimView._shared_fabric_hierarchy.track_world_xform_changes(True) - XformPrimView._shared_fabric_stage_key = stage_id - self._fabric_hierarchy = XformPrimView._shared_fabric_hierarchy - - # Step 2: Create index arrays for batch operations - self._default_view_indices = wp.zeros((self.count,), dtype=wp.uint32).to(self._device) - wp.launch( - kernel=fabric_utils.arange_k, - dim=self.count, - inputs=[self._default_view_indices], - device=self._device, - ) - wp.synchronize() # Ensure indices are ready - - # Step 3: Create Fabric selection with attribute filtering - # SelectPrims expects device format like "cuda:0" not "cuda" - # - # KNOWN ISSUE: SelectPrims may return prims in a different order than self._prims - # (which comes from USD's find_matching_prims). We create a bidirectional mapping - # (_view_to_fabric and _fabric_to_view) to handle this ordering difference. - # This works correctly for full-view operations but partial indexing still has issues. - # - # NOTE: SelectPrims only supports "cuda:0" regardless of which GPU the simulation - # is running on. In multi-GPU setups, we must use "cuda:0" for SelectPrims even if - # the simulation device is "cuda:1" or higher. - fabric_device = self._device - if self._device == "cuda": - logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.") - fabric_device = "cuda:0" - elif self._device.startswith("cuda:"): - # SelectPrims only supports cuda:0, so we always use cuda:0 for SelectPrims - # even if the simulation is running on a different GPU - if self._device != "cuda:0": - logger.debug( - f"SelectPrims only supports cuda:0. Using cuda:0 for SelectPrims " - f"even though simulation device is {self._device}." - ) - fabric_device = "cuda:0" - - index_selection = fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - ], - device=fabric_device, - ) - - # Step 4: Create bidirectional mapping between view and fabric indices - # Note: fabric_to_view is tied to fabric_device (cuda:0) because it's created from SelectPrims. - # view_to_fabric must also be on fabric_device since it's always used with fabricarrays in kernels. - self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32).to(fabric_device) - fabric_to_view = wp.fabricarray(index_selection, self._view_index_attr) - - wp.launch( - kernel=fabric_utils.set_view_to_fabric_array, - dim=fabric_to_view.shape[0], - inputs=[fabric_to_view, self._view_to_fabric], - device=fabric_device, - ) - # Synchronize to ensure mapping is ready before any operations - wp.synchronize() - - # Pre-allocate reusable output buffers for read operations - self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) - self._fabric_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) - self._fabric_scales_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) - - # Create Warp views of the PyTorch tensors - self._fabric_positions_buffer = wp.from_torch(self._fabric_positions_torch, dtype=wp.float32) - self._fabric_orientations_buffer = wp.from_torch(self._fabric_orientations_torch, dtype=wp.float32) - self._fabric_scales_buffer = wp.from_torch(self._fabric_scales_torch, dtype=wp.float32) - - # Pre-allocate reusable output buffers for local pose reads - self._fabric_local_translations_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) - self._fabric_local_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) - - self._fabric_local_translations_buffer = wp.from_torch( - self._fabric_local_translations_torch, dtype=wp.float32 - ) - self._fabric_local_orientations_buffer = wp.from_torch( - self._fabric_local_orientations_torch, dtype=wp.float32 - ) - - # Dummy array for unused outputs (always empty) - self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - # Cached selection / fabricarray for local matrices (lazy-initialized) - self._local_selection = None - self._fabric_local_array = None - - # Cache Fabric stage to avoid expensive get_current_stage() calls - self._fabric_stage = fabric_stage - - # Store fabric_device for use in kernel launches that involve fabricarrays - self._fabric_device = fabric_device - - self._fabric_initialized = True - # Force a one-time USD->Fabric sync on first read to pick up any USD edits - # made after the view was constructed. - self._fabric_usd_sync_done = False - - def _get_world_matrices_as_fabricarray(self) -> wp.fabricarray: - """Create a fresh fabricarray for world matrices. - - Recreating both the PrimSelection and fabricarray on each write ensures Fabric's - journaling system marks the attribute as dirty, so downstream consumers (renderers) - observe the update. - """ - import usdrt - - # Full rebuild of select prim, there is a bug in PrepareForReuse that - # causes CPU buffers to take precedence over GPU buffers. - if True: - self._world_selection = self._fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), - ], - device=self._fabric_device, - ) - self._fabric_array = wp.fabricarray(self._world_selection, self._WORLD_MATRIX_ATTR) - else: - self._world_selection.PrepareForReuse() - - return self._fabric_array - - def _get_local_matrices_as_fabricarray(self) -> wp.fabricarray: - """Create a fresh fabricarray for local matrices. - - Mirrors :meth:`_get_world_matrices_as_fabricarray` but targets - ``omni:fabric:localMatrix``. - """ - import usdrt - - # Full rebuild of select prim, there is a bug in PrepareForReuse that - # causes CPU buffers to take precedence over GPU buffers. - if True: - self._local_selection = self._fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), - ], - device=self._fabric_device, - ) - self._fabric_local_array = wp.fabricarray(self._local_selection, self._LOCAL_MATRIX_ATTR) - else: - self._local_selection.PrepareForReuse() - - return self._fabric_local_array - - def _sync_fabric_from_usd_once(self) -> None: - """Sync Fabric world matrices from USD once, on the first read.""" - # Ensure Fabric is initialized - if not self._fabric_initialized: - self._initialize_fabric() - - # Read authoritative transforms from USD and write once into Fabric. - positions_usd, orientations_usd = self._get_world_poses_usd() - scales_usd = self._get_scales_usd() - - prev_sync = self._sync_usd_on_fabric_write - self._sync_usd_on_fabric_write = False - self._set_world_poses_fabric(positions_usd, orientations_usd) - self._set_scales_fabric(scales_usd) - self._sync_usd_on_fabric_write = prev_sync - - # Also sync local matrices so _get_local_poses_fabric() returns - # correct values. Write directly to localMatrix without calling - # update_world_xforms() — the world matrices above are already correct. - local_trans_usd, local_orient_usd = self._get_local_poses_usd() - local_trans_wp = wp.from_torch(local_trans_usd) - local_orient_wp = wp.from_torch(local_orient_usd) - scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - indices_wp = self._resolve_indices_wp(None) - local_matrices = self._get_local_matrices_as_fabricarray() - - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=self.count, - inputs=[ - local_matrices, - local_trans_wp, - local_orient_wp, - scales_wp, - False, - False, - False, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - wp.synchronize() - - self._fabric_usd_sync_done = True - - def _resolve_indices_wp(self, indices: Sequence[int] | None) -> wp.array: - """Resolve view indices as a Warp array.""" - if indices is None or indices == slice(None): - if self._default_view_indices is None: - raise RuntimeError("Fabric indices are not initialized.") - return self._default_view_indices - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - return wp.array(indices_list, dtype=wp.uint32).to(self._device) diff --git a/source/isaaclab/isaaclab/sim/views/xform_usd_backend.py b/source/isaaclab/isaaclab/sim/views/xform_usd_backend.py new file mode 100644 index 000000000000..f03b59e2a21e --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_usd_backend.py @@ -0,0 +1,220 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import Sequence + +import numpy as np +import torch + +from pxr import Gf, Sdf, Usd, UsdGeom, Vt + + +class UsdBackend: + """USD-based transform backend for :class:`XformPrimView`. + + Reads and writes transforms through USD's ``xformOp`` attributes, using + :class:`UsdGeom.XformCache` for efficient world-space computations. + """ + + def __init__(self, prims: list[Usd.Prim], device: str): + self._prims = prims + self._device = device + self._ALL_INDICES = list(range(len(prims))) + + @property + def count(self) -> int: + """Number of prims managed by this backend.""" + return len(self._prims) + + def initialize(self) -> None: + """No-op for the USD backend (no deferred setup needed).""" + + # ------------------------------------------------------------------ + # Setters + # ------------------------------------------------------------------ + + def set_world_poses( + self, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set world-space poses by converting to local space and writing to USD.""" + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Validate inputs + if positions is not None: + if positions.shape != (len(indices_list), 3): + raise ValueError( + f"Expected positions shape ({len(indices_list)}, 3), got {positions.shape}. " + "Number of positions must match the number of prims in the view." + ) + positions_array = Vt.Vec3dArray.FromNumpy(positions.cpu().numpy()) + else: + positions_array = None + if orientations is not None: + if orientations.shape != (len(indices_list), 4): + raise ValueError( + f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}. " + "Number of orientations must match the number of prims in the view." + ) + orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) + else: + orientations_array = None + + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + parent_prim = prim.GetParent() + + world_pos = positions_array[idx] if positions_array is not None else None + world_quat = orientations_array[idx] if orientations_array is not None else None + + if parent_prim.IsValid() and parent_prim.GetPath() != Sdf.Path.absoluteRootPath: + if positions_array is None or orientations_array is None: + prim_tf = xform_cache.GetLocalToWorldTransform(prim) + prim_tf.Orthonormalize() + if world_pos is not None: + prim_tf.SetTranslateOnly(world_pos) + if world_quat is not None: + prim_tf.SetRotateOnly(world_quat) + else: + prim_tf = Gf.Matrix4d() + prim_tf.SetTranslateOnly(world_pos) + prim_tf.SetRotateOnly(world_quat) + + parent_world_tf = xform_cache.GetLocalToWorldTransform(parent_prim) + local_tf = prim_tf * parent_world_tf.GetInverse() + local_pos = local_tf.ExtractTranslation() + local_quat = local_tf.ExtractRotationQuat() + else: + local_pos = world_pos + local_quat = world_quat + + if local_pos is not None: + prim.GetAttribute("xformOp:translate").Set(local_pos) + if local_quat is not None: + prim.GetAttribute("xformOp:orient").Set(local_quat) + + def set_local_poses( + self, + translations: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set local-space poses directly on USD xformOp attributes.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + if translations is not None: + if translations.shape != (len(indices_list), 3): + raise ValueError(f"Expected translations shape ({len(indices_list)}, 3), got {translations.shape}.") + translations_array = Vt.Vec3dArray.FromNumpy(translations.cpu().numpy()) + else: + translations_array = None + if orientations is not None: + if orientations.shape != (len(indices_list), 4): + raise ValueError(f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}.") + orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) + else: + orientations_array = None + + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + if translations_array is not None: + prim.GetAttribute("xformOp:translate").Set(translations_array[idx]) + if orientations_array is not None: + prim.GetAttribute("xformOp:orient").Set(orientations_array[idx]) + + def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) -> None: + """Set scales on USD ``xformOp:scale`` attributes.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + if scales.shape != (len(indices_list), 3): + raise ValueError(f"Expected scales shape ({len(indices_list)}, 3), got {scales.shape}.") + + scales_array = Vt.Vec3dArray.FromNumpy(scales.cpu().numpy()) + + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + prim.GetAttribute("xformOp:scale").Set(scales_array[idx]) + + # ------------------------------------------------------------------ + # Getters + # ------------------------------------------------------------------ + + def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Get world-space poses via :class:`UsdGeom.XformCache`.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + positions = Vt.Vec3dArray(len(indices_list)) + orientations = Vt.QuatdArray(len(indices_list)) + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + prim_tf = xform_cache.GetLocalToWorldTransform(prim) + prim_tf.Orthonormalize() + positions[idx] = prim_tf.ExtractTranslation() + orientations[idx] = prim_tf.ExtractRotationQuat() + + positions = torch.tensor(np.array(positions), dtype=torch.float32, device=self._device) + orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) + return positions, orientations # type: ignore + + def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Get local-space poses via :class:`UsdGeom.XformCache`.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + translations = Vt.Vec3dArray(len(indices_list)) + orientations = Vt.QuatdArray(len(indices_list)) + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + prim_tf = xform_cache.GetLocalTransformation(prim)[0] + prim_tf.Orthonormalize() + translations[idx] = prim_tf.ExtractTranslation() + orientations[idx] = prim_tf.ExtractRotationQuat() + + translations = torch.tensor(np.array(translations), dtype=torch.float32, device=self._device) + orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) + return translations, orientations # type: ignore + + def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: + """Get scales from USD ``xformOp:scale`` attributes.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + scales = Vt.Vec3dArray(len(indices_list)) + + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + scales[idx] = prim.GetAttribute("xformOp:scale").Get() + + return torch.tensor(np.array(scales), dtype=torch.float32, device=self._device) From 8afeb88674d2789baf0de9d47b11a12b808a2d14 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 31 Mar 2026 16:20:41 -0700 Subject: [PATCH 11/20] Reformat --- .../isaaclab/isaaclab/sensors/camera/camera.py | 4 +++- .../isaaclab/sim/views/xform_fabric_backend.py | 17 +++++------------ .../isaaclab/sim/views/xform_prim_view.py | 8 ++------ .../isaaclab/test/sensors/test_tiled_camera.py | 11 +++++++---- 4 files changed, 17 insertions(+), 23 deletions(-) diff --git a/source/isaaclab/isaaclab/sensors/camera/camera.py b/source/isaaclab/isaaclab/sensors/camera/camera.py index 5470162dceaa..f03c399500f2 100644 --- a/source/isaaclab/isaaclab/sensors/camera/camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/camera.py @@ -450,7 +450,9 @@ def _initialize_impl(self): super()._initialize_impl() # Create a view for the sensor with Fabric enabled for fast pose queries, otherwise position will be stale. self._view = XformPrimView( - self.cfg.prim_path, device=self._device, stage=self.stage, + self.cfg.prim_path, + device=self._device, + stage=self.stage, ) # Check that sizes are correct if self._view.count != self._num_envs: diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index 5372de458627..588469153584 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -112,12 +112,9 @@ def initialize(self) -> None: 0, ) - hierarchy = ( - usdrt.hierarchy.IFabricHierarchy() - .get_fabric_hierarchy( - fabric_stage.GetFabricId(), - fabric_stage.GetStageIdAsStageId(), - ) + hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( + fabric_stage.GetFabricId(), + fabric_stage.GetStageIdAsStageId(), ) hierarchy.update_world_xforms() hierarchy.track_local_xform_changes(True) @@ -182,12 +179,8 @@ def initialize(self) -> None: self._fabric_local_translations_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) self._fabric_local_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) - self._fabric_local_translations_buffer = wp.from_torch( - self._fabric_local_translations_torch, dtype=wp.float32 - ) - self._fabric_local_orientations_buffer = wp.from_torch( - self._fabric_local_orientations_torch, dtype=wp.float32 - ) + self._fabric_local_translations_buffer = wp.from_torch(self._fabric_local_translations_torch, dtype=wp.float32) + self._fabric_local_orientations_buffer = wp.from_torch(self._fabric_local_orientations_torch, dtype=wp.float32) # Dummy buffer for unused kernel outputs (always empty) self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) diff --git a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py index 5f6276166d8e..31296a81abb3 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py +++ b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py @@ -152,13 +152,9 @@ def __init__( # ---- Create backends ------------------------------------------------ if use_fabric: - self._backend: XformBackend = FabricBackend( - self._prims, self._device - ) + self._backend: XformBackend = FabricBackend(self._prims, self._device) self._sync_backends: list[XformBackend] = ( - [UsdBackend(self._prims, self._device)] - if sync_usd_on_fabric_write - else [] + [UsdBackend(self._prims, self._device)] if sync_usd_on_fabric_write else [] ) else: self._backend = UsdBackend(self._prims, self._device) diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index a09f766bd628..4bece04b9800 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -1620,10 +1620,13 @@ def test_output_equal_to_usd_camera_intrinsics(setup_camera, device): [(TiledCamera, TiledCameraCfg), (Camera, CameraCfg)], ids=["tiled", "non_tiled"], ) -@pytest.mark.parametrize("device", [ - "cpu", - "cuda:0", - ]) +@pytest.mark.parametrize( + "device", + [ + "cpu", + "cuda:0", + ], +) @pytest.mark.isaacsim_ci def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls, cfg_cls): """Test that moving a camera is reflected in rendered depth. From f5a3623852d217272413fa96edaf312998854ad4 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Tue, 31 Mar 2026 16:22:05 -0700 Subject: [PATCH 12/20] Revert unwanted changes --- .../isaaclab_physx/renderers/isaac_rtx_renderer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py b/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py index 5cf06657e0b8..22b07f13def0 100644 --- a/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py +++ b/source/isaaclab_physx/isaaclab_physx/renderers/isaac_rtx_renderer.py @@ -19,7 +19,6 @@ from isaaclab.app.settings_manager import get_settings_manager from isaaclab.renderers import BaseRenderer -from isaaclab.utils.math import convert_camera_frame_orientation_convention from isaaclab.utils.warp.kernels import reshape_tiled_image from .isaac_rtx_renderer_utils import ensure_isaac_rtx_render_update @@ -179,6 +178,8 @@ def update_camera( orientations: torch.Tensor, intrinsics: torch.Tensor, ): + """No-op for Replicator - uses USD camera prims directly. + See :meth:`~isaaclab.renderers.base_renderer.BaseRenderer.update_camera`.""" pass def render(self, render_data: IsaacRtxRenderData): From 6466ffebd0ac13fb2d78d6caa61b8ec638ad36a4 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Wed, 1 Apr 2026 17:24:38 -0700 Subject: [PATCH 13/20] Replace fabric array with indexedfabricarray --- .../sim/views/xform_fabric_backend.py | 502 ++++++++---------- source/isaaclab/isaaclab/utils/warp/fabric.py | 112 +++- 2 files changed, 322 insertions(+), 292 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index 588469153584..aeeda86c260a 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -24,6 +24,40 @@ class FabricBackend: Uses NVIDIA's Fabric API with Warp GPU kernels for high-performance batch transform operations. + + Selected primitives based on attributes such as local and world matrices. + (all prims in selection, could be in different buckets): + ┌─────────┬───────────────┬───────────┐ + │ Fab Idx │ Prim Path │ attribute │ + ├─────────┼───────────────┼───────────┤ + │ 0 │ /World/Light │ [ ... ] │ + │ 1 │ /World/Cam_2 │ [ ... ] │ + │ 2 │ /World/Ground │ [ ... ] │ + │ 3 │ /World/Cam_0 │ [ ... ] │ + │ 4 │ /World/Table │ [ ... ] │ + │ 5 │ /World/Cam_1 │ [ ... ] │ + │ 6 │ /World/Robot │ [ ... ] │ + └─────────┴───────────────┴───────────┘ + + Example view of 3 prims, order of the paths defines order of indices + ┌──────────┬──────────────┐ + │ View Idx │ Prim Path │ + ├──────────┼──────────────┤ + │ 0 │ /World/Cam_0 │ + │ 1 │ /World/Cam_1 │ + │ 2 │ /World/Cam_2 │ + └──────────┴──────────────┘ + + Mapping from view indices to fabric indices happens through a fabric index array: + ┌──────────┬─────────┐ + │ View Idx │ Fab Idx │ + ├──────────┼─────────┤ + │ 0 │ 3 │ + │ 1 │ 5 │ + │ 2 │ 1 │ + └──────────┴─────────┘ + + If topology of the fabric changes, then all fabric indices need to be rebuilt. """ _WORLD_MATRIX_ATTR = "omni:fabric:worldMatrix" @@ -39,58 +73,32 @@ def __init__( ): self._prims = prims self._device = device - self._view_index_attr = f"isaaclab:view_index:{abs(id(self))}" # Lazy-initialized state (None until initialize() runs) - self._view_to_fabric: wp.array | None = None - self._default_view_indices: wp.array | None = None - self._fabric_hierarchy = None - self._world_selection = None - self._local_selection = None - - # ------------------------------------------------------------------ - # Properties - # ------------------------------------------------------------------ - - @property - def count(self) -> int: - """Number of prims managed by this backend.""" - return len(self._prims) + self._view_indices: wp.array | None = None + self._fabric_indices: wp.array | None = None + self._selection = None - @property - def prim_paths(self) -> list[str]: - """Prim path strings (lazily cached).""" - if not hasattr(self, "_prim_paths_cache"): - self._prim_paths_cache = [p.GetPath().pathString for p in self._prims] - return self._prim_paths_cache - - # ------------------------------------------------------------------ - # Initialization - # ------------------------------------------------------------------ - - def initialize(self) -> None: - """Set up Fabric batch infrastructure for GPU-accelerated pose queries. - - Idempotent — subsequent calls after the first are no-ops. + # Resolve the Fabric device string (SelectPrims only supports cuda:0) + if self._device.startswith("cuda"): + if self._device == "cuda": + logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.") + elif self._device != "cuda:0": + logger.debug( + "SelectPrims only supports cuda:0. Using cuda:0 even though simulation device is %s.", + self._device, + ) + fabric_device = "cuda:0" + else: + fabric_device = self._device + self._fabric_device = fabric_device - Ensures all prims have the required Fabric hierarchy attributes - (``omni:fabric:localMatrix`` and ``omni:fabric:worldMatrix``) and - creates the index mapping, selections, and pre-allocated buffers - needed for Warp kernel launches. - """ - if self._fabric_hierarchy is not None: - return import usdrt from usdrt import Rt # noqa: F401 — imported for side-effects stage_id = sim_utils.get_current_stage_id() fabric_stage = usdrt.Usd.Stage.Attach(stage_id) - - # Ensure every prim carries the view-index attribute - for i in range(self.count): - rt_prim = fabric_stage.GetPrimAtPath(self.prim_paths[i]) - rt_prim.CreateAttribute(self._view_index_attr, usdrt.Sdf.ValueTypeNames.UInt, custom=True) - rt_prim.GetAttribute(self._view_index_attr).Set(i) + self._fabric_stage = fabric_stage # Reuse (or create) a hierarchy handle for this stage. if stage_id not in FabricBackend._hierarchy_cache: @@ -124,47 +132,12 @@ def initialize(self) -> None: self._fabric_hierarchy = FabricBackend._hierarchy_cache[stage_id] self._stage_id = stage_id - # Default view-index array (0 … count-1) - self._default_view_indices = wp.zeros((self.count,), dtype=wp.uint32).to(self._device) - wp.launch( - kernel=fabric_utils.arange_k, - dim=self.count, - inputs=[self._default_view_indices], - device=self._device, - ) - wp.synchronize() + # Build the view → fabric index array from PrimSelection path ordering. + # Default view-index array [0, 1, ..., count-1] for "all prims". + self._build_view_to_fabric_index_mapping() - # Resolve the Fabric device string (SelectPrims only supports cuda:0) - fabric_device = self._device - if self._device == "cuda": - logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.") - fabric_device = "cuda:0" - elif self._device.startswith("cuda:"): - if self._device != "cuda:0": - logger.debug( - f"SelectPrims only supports cuda:0. Using cuda:0 for SelectPrims " - f"even though simulation device is {self._device}." - ) - fabric_device = "cuda:0" - - # Build the bidirectional view ↔ fabric index mapping - index_selection = fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - ], - device=fabric_device, - ) - - self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32).to(fabric_device) - fabric_to_view = wp.fabricarray(index_selection, self._view_index_attr) - - wp.launch( - kernel=fabric_utils.set_view_to_fabric_array, - dim=fabric_to_view.shape[0], - inputs=[fabric_to_view, self._view_to_fabric], - device=fabric_device, - ) - wp.synchronize() + self._world_selection = self._select_single_attr(self._WORLD_MATRIX_ATTR) + self._local_selection = self._select_single_attr(self._LOCAL_MATRIX_ATTR) # Pre-allocate reusable output buffers (world poses) self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) @@ -185,11 +158,21 @@ def initialize(self) -> None: # Dummy buffer for unused kernel outputs (always empty) self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - self._local_selection = None - self._fabric_local_array = None + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ - self._fabric_stage = fabric_stage - self._fabric_device = fabric_device + @property + def count(self) -> int: + """Number of prims managed by this backend.""" + return len(self._prims) + + @property + def prim_paths(self) -> list[str]: + """Prim path strings (lazily cached).""" + if not hasattr(self, "_prim_paths_cache"): + self._prim_paths_cache = [p.GetPath().pathString for p in self._prims] + return self._prim_paths_cache # ------------------------------------------------------------------ # Setters @@ -202,45 +185,16 @@ def set_world_poses( indices: Sequence[int] | None = None, ) -> None: """Write world poses to Fabric ``omni:fabric:worldMatrix`` via a Warp kernel.""" - self.initialize() + # if local transforms were set, we need to update the world transforms if self._stage_id in FabricBackend._dirty_stages: self._fabric_hierarchy.update_world_xforms() FabricBackend._dirty_stages.discard(self._stage_id) - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - if positions is not None: - positions_wp = wp.from_torch(positions) - else: - positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - if orientations is not None: - orientations_wp = wp.from_torch(orientations) - else: - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - - scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - world_matrices = self._get_world_matrices_as_fabricarray() - - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, - orientations_wp, - scales_wp, - False, - False, - False, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, + fabric_indices = self._convert_view_to_fabric_indices(indices) + self._compose_transforms( + self._get_world_indexed_fabricarray(), fabric_indices, positions=positions, orientations=orientations ) - wp.synchronize() def set_local_poses( self, @@ -254,73 +208,17 @@ def set_local_poses( :pyobj:`IFabricHierarchy` and marks world transforms as dirty so that a subsequent read will propagate the change. """ - self.initialize() - - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - if translations is not None: - translations_wp = wp.from_torch(translations) - else: - translations_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - if orientations is not None: - orientations_wp = wp.from_torch(orientations) - else: - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - - scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - local_matrices = self._get_local_matrices_as_fabricarray() - - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - local_matrices, - translations_wp, - orientations_wp, - scales_wp, - False, - False, - False, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, + fabric_indices = self._convert_view_to_fabric_indices(indices) + self._compose_transforms( + self._get_local_indexed_fabricarray(), fabric_indices, positions=translations, orientations=orientations ) - wp.synchronize() FabricBackend._dirty_stages.add(self._stage_id) def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) -> None: """Write scales into the Fabric world matrix via a Warp kernel.""" - self.initialize() - - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - scales_wp = wp.from_torch(scales) - positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - world_matrices = self._get_world_matrices_as_fabricarray() - - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, - orientations_wp, - scales_wp, - False, - False, - False, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - wp.synchronize() + fabric_indices = self._convert_view_to_fabric_indices(indices) + self._compose_transforms(self._get_world_indexed_fabricarray(), fabric_indices, scales=scales) # ------------------------------------------------------------------ # Getters @@ -328,95 +226,57 @@ def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: """Read world poses from Fabric and decompose via a Warp kernel.""" - self.initialize() if self._stage_id in FabricBackend._dirty_stages: self._fabric_hierarchy.update_world_xforms() FabricBackend._dirty_stages.discard(self._stage_id) - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] + fabric_indices = self._convert_view_to_fabric_indices(indices) + count = fabric_indices.shape[0] + dummy = self._fabric_dummy_buffer use_cached_buffers = indices is None or indices == slice(None) if use_cached_buffers: positions_wp = self._fabric_positions_buffer orientations_wp = self._fabric_orientations_buffer - scales_wp = self._fabric_dummy_buffer else: positions_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) - scales_wp = self._fabric_dummy_buffer - - world_matrices = self._get_world_matrices_as_fabricarray() - wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, - orientations_wp, - scales_wp, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, + self._decompose_transforms( + self._get_world_indexed_fabricarray(), fabric_indices, positions_wp, orientations_wp, dummy ) if use_cached_buffers: - wp.synchronize() return self._fabric_positions_torch, self._fabric_orientations_torch - else: - positions = wp.to_torch(positions_wp) - orientations = wp.to_torch(orientations_wp) - return positions, orientations + return wp.to_torch(positions_wp), wp.to_torch(orientations_wp) def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: """Read local poses from Fabric and decompose via a Warp kernel.""" - self.initialize() - - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] + fabric_indices = self._convert_view_to_fabric_indices(indices) + count = fabric_indices.shape[0] + dummy = self._fabric_dummy_buffer use_cached_buffers = indices is None or indices == slice(None) if use_cached_buffers: translations_wp = self._fabric_local_translations_buffer orientations_wp = self._fabric_local_orientations_buffer - scales_wp = self._fabric_dummy_buffer else: translations_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) - scales_wp = self._fabric_dummy_buffer - - local_matrices = self._get_local_matrices_as_fabricarray() - wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, - inputs=[ - local_matrices, - translations_wp, - orientations_wp, - scales_wp, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, + self._decompose_transforms( + self._get_local_indexed_fabricarray(), fabric_indices, translations_wp, orientations_wp, dummy ) if use_cached_buffers: - wp.synchronize() return self._fabric_local_translations_torch, self._fabric_local_orientations_torch - else: - translations = wp.to_torch(translations_wp) - orientations = wp.to_torch(orientations_wp) - return translations, orientations + return wp.to_torch(translations_wp), wp.to_torch(orientations_wp) def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: """Read scales from Fabric world matrices and extract via a Warp kernel.""" - self.initialize() - - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] + fabric_indices = self._convert_view_to_fabric_indices(indices) + count = fabric_indices.shape[0] + dummy = self._fabric_dummy_buffer use_cached_buffers = indices is None or indices == slice(None) if use_cached_buffers: @@ -424,80 +284,152 @@ def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: else: scales_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) - positions_wp = self._fabric_dummy_buffer - orientations_wp = self._fabric_dummy_buffer - world_matrices = self._get_world_matrices_as_fabricarray() + self._decompose_transforms(self._get_world_indexed_fabricarray(), fabric_indices, dummy, dummy, scales_wp) + + if use_cached_buffers: + return self._fabric_scales_torch + return wp.to_torch(scales_wp) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _compose_transforms( + self, + matrices: wp.indexedfabricarray, + indices_wp: wp.array, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + scales: torch.Tensor | None = None, + ) -> None: + """Launch the compose kernel to write transform components into Fabric matrices. + + Converts non-``None`` torch tensors to Warp arrays and substitutes a + pre-allocated zero-length dummy for omitted components so the kernel + leaves existing values untouched. + """ + dummy = self._fabric_dummy_buffer + positions_wp = wp.from_torch(positions) if positions is not None else dummy + orientations_wp = wp.from_torch(orientations) if orientations is not None else dummy + scales_wp = wp.from_torch(scales) if scales is not None else dummy wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], inputs=[ - world_matrices, + matrices, positions_wp, orientations_wp, scales_wp, + False, + False, + False, indices_wp, - self._view_to_fabric, ], device=self._fabric_device, ) + wp.synchronize() - if use_cached_buffers: - wp.synchronize() - return self._fabric_scales_torch - else: - return wp.to_torch(scales_wp) + def _decompose_transforms( + self, + matrices: wp.indexedfabricarray, + indices_wp: wp.array, + positions_wp: wp.array, + orientations_wp: wp.array, + scales_wp: wp.array, + ) -> None: + """Launch the decompose kernel to read transform components from Fabric matrices.""" + wp.launch( + kernel=fabric_utils.decompose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[matrices, positions_wp, orientations_wp, scales_wp, indices_wp], + device=self._fabric_device, + ) + wp.synchronize() - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ + def _build_view_to_fabric_index_mapping(self) -> None: + """Build the view index to fabric index array from PrimSelection path ordering.""" - def _get_world_matrices_as_fabricarray(self) -> wp.fabricarray: - """Create a fresh :class:`wp.fabricarray` backed by ``omni:fabric:worldMatrix``. + self._view_indices = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) - Rebuilding both the :class:`PrimSelection` and fabricarray each call - ensures Fabric's journaling marks the attribute dirty for downstream - consumers (renderers, etc.). - """ - import usdrt + # Assign to each prim an index + fabric_paths = self._index_selection.GetPaths() + path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} - if True: - self._world_selection = self._fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), - ], - device=self._fabric_device, - ) - self._fabric_array = wp.fabricarray(self._world_selection, self._WORLD_MATRIX_ATTR) - else: - self._world_selection.PrepareForReuse() + # Look up the index for each prim observed by this view + fabric_indices: list[int] = [] + for prim_path in self.prim_paths: + fabric_idx = path_to_fabric_idx.get(prim_path) + if fabric_idx is None: + raise RuntimeError( + f"Prim '{prim_path}' not found in Fabric selection. Ensure the hierarchy has been populated." + ) + fabric_indices.append(fabric_idx) - return self._fabric_array + self._fabric_indices = wp.array(fabric_indices, dtype=wp.int32).to(self._fabric_device) - def _get_local_matrices_as_fabricarray(self) -> wp.fabricarray: - """Create a fresh :class:`wp.fabricarray` backed by ``omni:fabric:localMatrix``.""" + def _select_single_attr(self, attr_name: str): + """Create a ``PrimSelection`` requiring only *one* matrix attribute.""" import usdrt - if True: - self._local_selection = self._fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite), - ], - device=self._fabric_device, + selection = self._fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.Matrix4d, attr_name, usdrt.Usd.Access.ReadWrite), + ], + device=self._fabric_device, + want_paths=True, + ) + + self._maybe_rebuild_fabric_indices(selection) + return selection + + def _maybe_rebuild_fabric_indices(self, selection) -> None: + """Rebuild ``_fabric_indices`` if the selection ordering changed.""" + sel_paths = selection.GetPaths() + path_to_idx: dict[str, int] = {str(p): i for i, p in enumerate(sel_paths)} + + needs_rebuild = False + current_indices = self._fabric_indices.numpy() + for view_idx, prim_path in enumerate(self.prim_paths): + new_fabric_idx = path_to_idx.get(prim_path) + if new_fabric_idx is None: + raise RuntimeError(f"Prim '{prim_path}' disappeared from Fabric selection after topology change.") + if new_fabric_idx != current_indices[view_idx]: + needs_rebuild = True + break + + if needs_rebuild: + logger.debug( + "Fabric topology changed — rebuilding fabric indices for %d prims.", + self.count, ) - self._fabric_local_array = wp.fabricarray(self._local_selection, self._LOCAL_MATRIX_ATTR) - else: - self._local_selection.PrepareForReuse() + new_indices = [path_to_idx[p] for p in self.prim_paths] + self._fabric_indices = wp.array(new_indices, dtype=wp.int32).to(self._fabric_device) + + def _get_world_indexed_fabricarray(self) -> wp.indexedfabricarray: + # if self._world_selection.PrepareForReuse(): + self._world_selection = self._select_single_attr(self._WORLD_MATRIX_ATTR) + fa = wp.fabricarray(self._world_selection, self._WORLD_MATRIX_ATTR) + return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) + + def _get_local_indexed_fabricarray(self) -> wp.indexedfabricarray: + # if self._local_selection.PrepareForReuse(): + self._local_selection = self._select_single_attr(self._LOCAL_MATRIX_ATTR) + fa = wp.fabricarray(self._local_selection, self._LOCAL_MATRIX_ATTR) + return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) - return self._fabric_local_array + def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.array: + """Convert requested view indices to fabric indices. - def _resolve_indices_wp(self, indices: Sequence[int] | None) -> wp.array: - """Convert view indices to a Warp :class:`wp.array`.""" + Args: + indices: Requested path indices. If None, then all indices are used. + + Returns: + A warp array of fabric indices. + """ if indices is None or indices == slice(None): - if self._default_view_indices is None: + if self._view_indices is None: raise RuntimeError("Fabric indices are not initialized.") - return self._default_view_indices + return self._view_indices indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) return wp.array(indices_list, dtype=wp.uint32).to(self._device) diff --git a/source/isaaclab/isaaclab/utils/warp/fabric.py b/source/isaaclab/isaaclab/utils/warp/fabric.py index a48f773f4991..e38ea0e7b337 100644 --- a/source/isaaclab/isaaclab/utils/warp/fabric.py +++ b/source/isaaclab/isaaclab/utils/warp/fabric.py @@ -18,12 +18,14 @@ if TYPE_CHECKING: FabricArrayUInt32 = Any FabricArrayMat44d = Any + IndexedFabricArrayMat44d = Any ArrayUInt32 = Any ArrayUInt32_1d = Any ArrayFloat32_2d = Any else: FabricArrayUInt32 = wp.fabricarray(dtype=wp.uint32) FabricArrayMat44d = wp.fabricarray(dtype=wp.mat44d) + IndexedFabricArrayMat44d = wp.indexedfabricarray(dtype=wp.mat44d) ArrayUInt32 = wp.array(ndim=1, dtype=wp.uint32) ArrayUInt32_1d = wp.array(dtype=wp.uint32) ArrayFloat32_2d = wp.array(ndim=2, dtype=wp.float32) @@ -37,13 +39,6 @@ def set_view_to_fabric_array(fabric_to_view: FabricArrayUInt32, view_to_fabric: view_to_fabric[view_idx] = wp.uint32(fabric_idx) -@wp.kernel(enable_backward=False) -def arange_k(a: ArrayUInt32_1d): - """Fill array with sequential indices.""" - tid = int(wp.tid()) - a[tid] = wp.uint32(tid) - - @wp.kernel(enable_backward=False) def decompose_fabric_transformation_matrix_to_warp_arrays( fabric_matrices: FabricArrayMat44d, @@ -163,6 +158,109 @@ def compose_fabric_transformation_matrix_from_warp_arrays( ) +@wp.kernel(enable_backward=False) +def decompose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + indices: ArrayUInt32, +): + """Decompose indexed Fabric transformation matrices into position, orientation, and scale. + + Like :func:`decompose_fabric_transformation_matrix_to_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices. + array_positions: Output array for positions [m], shape (N, 3). + array_orientations: Output array for quaternions in xyzw format, shape (N, 4). + array_scales: Output array for scales, shape (N, 3). + indices: View indices to process (subset selection). + """ + output_index = wp.tid() + view_index = indices[output_index] + + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + array_positions[output_index, 0] = position[0] + array_positions[output_index, 1] = position[1] + array_positions[output_index, 2] = position[2] + if array_orientations.shape[0] > 0: + array_orientations[output_index, 0] = rotation[0] + array_orientations[output_index, 1] = rotation[1] + array_orientations[output_index, 2] = rotation[2] + array_orientations[output_index, 3] = rotation[3] + if array_scales.shape[0] > 0: + array_scales[output_index, 0] = scale[0] + array_scales[output_index, 1] = scale[1] + array_scales[output_index, 2] = scale[2] + + +@wp.kernel(enable_backward=False) +def compose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + broadcast_positions: bool, + broadcast_orientations: bool, + broadcast_scales: bool, + indices: ArrayUInt32, +): + """Compose indexed Fabric transformation matrices from position, orientation, and scale. + + Like :func:`compose_fabric_transformation_matrix_from_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices to update. + array_positions: Input array for positions [m], shape (N, 3). + array_orientations: Input array for quaternions in xyzw format, shape (N, 4). + array_scales: Input array for scales, shape (N, 3). + broadcast_positions: If True, use first position for all prims. + broadcast_orientations: If True, use first orientation for all prims. + broadcast_scales: If True, use first scale for all prims. + indices: View indices to process (subset selection). + """ + i = wp.tid() + view_index = indices[i] + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + if broadcast_positions: + index = 0 + else: + index = i + position[0] = array_positions[index, 0] + position[1] = array_positions[index, 1] + position[2] = array_positions[index, 2] + if array_orientations.shape[0] > 0: + if broadcast_orientations: + index = 0 + else: + index = i + rotation[0] = array_orientations[index, 0] + rotation[1] = array_orientations[index, 1] + rotation[2] = array_orientations[index, 2] + rotation[3] = array_orientations[index, 3] + if array_scales.shape[0] > 0: + if broadcast_scales: + index = 0 + else: + index = i + scale[0] = array_scales[index, 0] + scale[1] = array_scales[index, 1] + scale[2] = array_scales[index, 2] + + fabric_matrices[view_index] = wp.mat44d( # type: ignore[arg-type] + wp.transpose(wp.transform_compose(position, rotation, scale)) # type: ignore[arg-type] + ) + + @wp.func def _decompose_transformation_matrix(m: Any): # -> tuple[wp.vec3f, wp.quatf, wp.vec3f] """Decompose a 4x4 transformation matrix into position, orientation, and scale. From 6f5294495d5e378ef33d304613e45053c3d3b859 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Wed, 1 Apr 2026 18:31:11 -0700 Subject: [PATCH 14/20] Rebuild indexing based on Fabric topology change --- .../sim/views/xform_fabric_backend.py | 97 +++++++++---------- .../test/sim/test_views_xform_prim.py | 65 +++++++++++++ 2 files changed, 109 insertions(+), 53 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index aeeda86c260a..14051fb9159f 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -74,10 +74,9 @@ def __init__( self._prims = prims self._device = device - # Lazy-initialized state (None until initialize() runs) + # Lazy-initialized state (None until __init__ body completes) self._view_indices: wp.array | None = None self._fabric_indices: wp.array | None = None - self._selection = None # Resolve the Fabric device string (SelectPrims only supports cuda:0) if self._device.startswith("cuda"): @@ -132,12 +131,19 @@ def __init__( self._fabric_hierarchy = FabricBackend._hierarchy_cache[stage_id] self._stage_id = stage_id + # Index selection for all primitives populated by the hierarchy, used for tracking topology changes. + self._index_selection = fabric_stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.Matrix4d, self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read), + (usdrt.Sdf.ValueTypeNames.Matrix4d, self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read), + ], + device=fabric_device, + want_paths=True, + ) + # Build the view → fabric index array from PrimSelection path ordering. # Default view-index array [0, 1, ..., count-1] for "all prims". - self._build_view_to_fabric_index_mapping() - - self._world_selection = self._select_single_attr(self._WORLD_MATRIX_ATTR) - self._local_selection = self._select_single_attr(self._LOCAL_MATRIX_ATTR) + self._rebuild_view_to_fabric_index_mapping(force_rebuild=True) # Pre-allocate reusable output buffers (world poses) self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) @@ -193,7 +199,7 @@ def set_world_poses( fabric_indices = self._convert_view_to_fabric_indices(indices) self._compose_transforms( - self._get_world_indexed_fabricarray(), fabric_indices, positions=positions, orientations=orientations + self._get_world_rw_array(), fabric_indices, positions=positions, orientations=orientations ) def set_local_poses( @@ -210,7 +216,7 @@ def set_local_poses( """ fabric_indices = self._convert_view_to_fabric_indices(indices) self._compose_transforms( - self._get_local_indexed_fabricarray(), fabric_indices, positions=translations, orientations=orientations + self._get_local_rw_array(), fabric_indices, positions=translations, orientations=orientations ) FabricBackend._dirty_stages.add(self._stage_id) @@ -218,7 +224,7 @@ def set_local_poses( def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) -> None: """Write scales into the Fabric world matrix via a Warp kernel.""" fabric_indices = self._convert_view_to_fabric_indices(indices) - self._compose_transforms(self._get_world_indexed_fabricarray(), fabric_indices, scales=scales) + self._compose_transforms(self._get_world_rw_array(), fabric_indices, scales=scales) # ------------------------------------------------------------------ # Getters @@ -243,7 +249,7 @@ def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) self._decompose_transforms( - self._get_world_indexed_fabricarray(), fabric_indices, positions_wp, orientations_wp, dummy + self._get_world_ro_array(), fabric_indices, positions_wp, orientations_wp, dummy ) if use_cached_buffers: @@ -265,7 +271,7 @@ def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) self._decompose_transforms( - self._get_local_indexed_fabricarray(), fabric_indices, translations_wp, orientations_wp, dummy + self._get_local_ro_array(), fabric_indices, translations_wp, orientations_wp, dummy ) if use_cached_buffers: @@ -284,7 +290,7 @@ def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: else: scales_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) - self._decompose_transforms(self._get_world_indexed_fabricarray(), fabric_indices, dummy, dummy, scales_wp) + self._decompose_transforms(self._get_world_ro_array(), fabric_indices, dummy, dummy, scales_wp) if use_cached_buffers: return self._fabric_scales_torch @@ -347,9 +353,14 @@ def _decompose_transforms( ) wp.synchronize() - def _build_view_to_fabric_index_mapping(self) -> None: + def _rebuild_view_to_fabric_index_mapping(self, force_rebuild: bool = False) -> None: """Build the view index to fabric index array from PrimSelection path ordering.""" + # Rebuild indexing only when fabric topology has changed or whenever forced. + topology_changed = self._index_selection.PrepareForReuse() + if not (topology_changed or force_rebuild): + return + self._view_indices = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) # Assign to each prim an index @@ -368,55 +379,35 @@ def _build_view_to_fabric_index_mapping(self) -> None: self._fabric_indices = wp.array(fabric_indices, dtype=wp.int32).to(self._fabric_device) - def _select_single_attr(self, attr_name: str): - """Create a ``PrimSelection`` requiring only *one* matrix attribute.""" + def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: + """Create an indexed fabric array for a single attribute with the given access mode.""" import usdrt - selection = self._fabric_stage.SelectPrims( require_attrs=[ - (usdrt.Sdf.ValueTypeNames.Matrix4d, attr_name, usdrt.Usd.Access.ReadWrite), + (usdrt.Sdf.ValueTypeNames.Matrix4d, attr_name, access), ], device=self._fabric_device, - want_paths=True, ) + fa = wp.fabricarray(selection, attr_name) + return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) - self._maybe_rebuild_fabric_indices(selection) - return selection - - def _maybe_rebuild_fabric_indices(self, selection) -> None: - """Rebuild ``_fabric_indices`` if the selection ordering changed.""" - sel_paths = selection.GetPaths() - path_to_idx: dict[str, int] = {str(p): i for i, p in enumerate(sel_paths)} - - needs_rebuild = False - current_indices = self._fabric_indices.numpy() - for view_idx, prim_path in enumerate(self.prim_paths): - new_fabric_idx = path_to_idx.get(prim_path) - if new_fabric_idx is None: - raise RuntimeError(f"Prim '{prim_path}' disappeared from Fabric selection after topology change.") - if new_fabric_idx != current_indices[view_idx]: - needs_rebuild = True - break - - if needs_rebuild: - logger.debug( - "Fabric topology changed — rebuilding fabric indices for %d prims.", - self.count, - ) - new_indices = [path_to_idx[p] for p in self.prim_paths] - self._fabric_indices = wp.array(new_indices, dtype=wp.int32).to(self._fabric_device) + def _get_world_ro_array(self) -> wp.indexedfabricarray: + import usdrt + return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read) - def _get_world_indexed_fabricarray(self) -> wp.indexedfabricarray: - # if self._world_selection.PrepareForReuse(): - self._world_selection = self._select_single_attr(self._WORLD_MATRIX_ATTR) - fa = wp.fabricarray(self._world_selection, self._WORLD_MATRIX_ATTR) - return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) + def _get_world_rw_array(self) -> wp.indexedfabricarray: + import usdrt + self._rebuild_view_to_fabric_index_mapping() + return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) - def _get_local_indexed_fabricarray(self) -> wp.indexedfabricarray: - # if self._local_selection.PrepareForReuse(): - self._local_selection = self._select_single_attr(self._LOCAL_MATRIX_ATTR) - fa = wp.fabricarray(self._local_selection, self._LOCAL_MATRIX_ATTR) - return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) + def _get_local_ro_array(self) -> wp.indexedfabricarray: + import usdrt + return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read) + + def _get_local_rw_array(self) -> wp.indexedfabricarray: + import usdrt + self._rebuild_view_to_fabric_index_mapping() + return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.array: """Convert requested view indices to fabric indices. diff --git a/source/isaaclab/test/sim/test_views_xform_prim.py b/source/isaaclab/test/sim/test_views_xform_prim.py index 2de28b78b5a0..77dc46763d24 100644 --- a/source/isaaclab/test/sim/test_views_xform_prim.py +++ b/source/isaaclab/test/sim/test_views_xform_prim.py @@ -1511,3 +1511,68 @@ def test_fabric_usd_consistency(device): pos_fabric_after, quat_fabric_after = view_fabric.get_world_poses() torch.testing.assert_close(pos_fabric_after, new_positions, atol=1e-4, rtol=0) torch.testing.assert_close(quat_fabric_after, new_orientations, atol=1e-4, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_fabric_topology_change_write_lands_on_correct_prim(device): + """Test that writes land on the correct prim after Fabric topology changes. + + Creates 3 Camera prims, builds a view for Cam_0 and Cam_2 only, then + adds a custom attribute to Cam_1 (forcing a bucket reorganization in + Fabric). After that, writes new positions through the view and verifies + via raw usdrt reads that the correct prims received the data. + """ + import usdrt + + _skip_if_backend_unavailable("fabric", device) + + stage = sim_utils.get_current_stage() + + # Step 1: create 3 Camera prims with known positions + sim_utils.create_prim("/World/Cam_0", "Camera", translation=(1.0, 0.0, 0.0), stage=stage) + sim_utils.create_prim("/World/Cam_1", "Camera", translation=(2.0, 0.0, 0.0), stage=stage) + sim_utils.create_prim("/World/Cam_2", "Camera", translation=(3.0, 0.0, 0.0), stage=stage) + + # Step 2: create a Fabric view for Cam_0 and Cam_2 only + sim_ctx = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view = XformPrimView("/World/Cam_[02]", device=device) + assert view.count == 2 + assert set(view.prim_paths) == {"/World/Cam_0", "/World/Cam_2"} + + # Step 3: add a custom attribute to Cam_1 to force a bucket change + fab_stage = usdrt.Usd.Stage.Attach(sim_utils.get_current_stage_id()) + cam1 = fab_stage.GetPrimAtPath(usdrt.Sdf.Path("/World/Cam_1")) + cam1.CreateAttribute("custom:dummyFloat", usdrt.Sdf.ValueTypeNames.Float, True).Set(42.0) + + # Step 4: write new positions through the view + new_positions = torch.tensor( + [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], + dtype=torch.float32, + device=device, + ) + view.set_world_poses(positions=new_positions) + + # Step 5: verify via raw usdrt that the correct prims got the new data + def _read_fabric_position(prim_path: str) -> tuple[float, float, float]: + prim = fab_stage.GetPrimAtPath(usdrt.Sdf.Path(prim_path)) + mat = prim.GetAttribute("omni:fabric:worldMatrix").Get() + return (mat[3][0], mat[3][1], mat[3][2]) + + pos_cam0 = _read_fabric_position("/World/Cam_0") + pos_cam1 = _read_fabric_position("/World/Cam_1") + pos_cam2 = _read_fabric_position("/World/Cam_2") + + # Cam_0 should have (10, 20, 30) + assert abs(pos_cam0[0] - 10.0) < 0.01, f"Cam_0 x: expected 10.0, got {pos_cam0[0]}" + assert abs(pos_cam0[1] - 20.0) < 0.01, f"Cam_0 y: expected 20.0, got {pos_cam0[1]}" + assert abs(pos_cam0[2] - 30.0) < 0.01, f"Cam_0 z: expected 30.0, got {pos_cam0[2]}" + + # Cam_2 should have (40, 50, 60) + assert abs(pos_cam2[0] - 40.0) < 0.01, f"Cam_2 x: expected 40.0, got {pos_cam2[0]}" + assert abs(pos_cam2[1] - 50.0) < 0.01, f"Cam_2 y: expected 50.0, got {pos_cam2[1]}" + assert abs(pos_cam2[2] - 60.0) < 0.01, f"Cam_2 z: expected 60.0, got {pos_cam2[2]}" + + # Cam_1 should NOT have been touched (still at its original position) + assert abs(pos_cam1[0] - 2.0) < 0.01 or abs(pos_cam1[0]) < 0.01, ( + f"Cam_1 x: expected ~2.0 or ~0.0, got {pos_cam1[0]} — write leaked to wrong prim!" + ) From 32ea1499d1ffcb2ee7038541e4043c8080345ac4 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Wed, 1 Apr 2026 23:22:19 -0700 Subject: [PATCH 15/20] Small refactoring --- .../sim/views/xform_fabric_backend.py | 128 +++++++++--------- 1 file changed, 65 insertions(+), 63 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index 14051fb9159f..1fd30b5e1753 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -74,10 +74,6 @@ def __init__( self._prims = prims self._device = device - # Lazy-initialized state (None until __init__ body completes) - self._view_indices: wp.array | None = None - self._fabric_indices: wp.array | None = None - # Resolve the Fabric device string (SelectPrims only supports cuda:0) if self._device.startswith("cuda"): if self._device == "cuda": @@ -87,63 +83,65 @@ def __init__( "SelectPrims only supports cuda:0. Using cuda:0 even though simulation device is %s.", self._device, ) - fabric_device = "cuda:0" + device = "cuda:0" else: - fabric_device = self._device - self._fabric_device = fabric_device + device = self._device + self._device = device import usdrt from usdrt import Rt # noqa: F401 — imported for side-effects - stage_id = sim_utils.get_current_stage_id() - fabric_stage = usdrt.Usd.Stage.Attach(stage_id) - self._fabric_stage = fabric_stage + self._stage_id = sim_utils.get_current_stage_id() + self._stage = usdrt.Usd.Stage.Attach(self._stage_id) # Reuse (or create) a hierarchy handle for this stage. - if stage_id not in FabricBackend._hierarchy_cache: + if self._stage_id not in FabricBackend._hierarchy_cache: pop = usdrt.population.IUtils() - pop.set_enable_usd_notice_handling( - fabric_stage.GetStageIdAsStageId(), - fabric_stage.GetFabricId(), - True, - ) pop.populate_from_usd( - fabric_stage.GetStageReaderWriterId(), - fabric_stage.GetStageIdAsStageId(), + self._stage.GetStageReaderWriterId(), + self._stage.GetStageIdAsStageId(), usdrt.Sdf.Path("/"), 0, ) - pop.apply_pending_usd_updates( - fabric_stage.GetStageIdAsStageId(), - fabric_stage.GetStageReaderWriterId(), - 0, - ) - hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( - fabric_stage.GetFabricId(), - fabric_stage.GetStageIdAsStageId(), + self._stage.GetFabricId(), + self._stage.GetStageIdAsStageId(), ) hierarchy.update_world_xforms() hierarchy.track_local_xform_changes(True) hierarchy.track_world_xform_changes(True) - FabricBackend._hierarchy_cache[stage_id] = hierarchy - - self._fabric_hierarchy = FabricBackend._hierarchy_cache[stage_id] - self._stage_id = stage_id - - # Index selection for all primitives populated by the hierarchy, used for tracking topology changes. - self._index_selection = fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.Matrix4d, self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read), - ], - device=fabric_device, + FabricBackend._hierarchy_cache[self._stage_id] = hierarchy + + self._fabric_hierarchy = FabricBackend._hierarchy_cache[self._stage_id] + + matrix = usdrt.Sdf.ValueTypeNames.Matrix4d + ro = usdrt.Usd.Access.Read + rw = usdrt.Usd.Access.ReadWrite + world_matrix_ro = (matrix, self._WORLD_MATRIX_ATTR, ro) + local_matrix_ro = (matrix, self._LOCAL_MATRIX_ATTR, ro) + world_matrix_rw = (matrix, self._WORLD_MATRIX_ATTR, rw) + local_matrix_rw = (matrix, self._LOCAL_MATRIX_ATTR, rw) + + # Persistent selections — one per (attribute x access-mode) combination. + # PrepareForReuse() is called before each use to detect topology changes. + self._index_selection_ro = self._stage.SelectPrims( + require_attrs=[world_matrix_ro, local_matrix_ro], + device=device, want_paths=True, ) + self._world_selection_rw = self._stage.SelectPrims(require_attrs=[world_matrix_rw], device=device) + self._local_selection_rw = self._stage.SelectPrims(require_attrs=[local_matrix_rw], device=device) + + # Cached indexed fabric arrays (rebuilt when topology changes) + self._world_ifa_ro: wp.indexedfabricarray | None = None + self._local_ifa_ro: wp.indexedfabricarray | None = None + self._world_ifa_rw: wp.indexedfabricarray | None = None + self._local_ifa_rw: wp.indexedfabricarray | None = None # Build the view → fabric index array from PrimSelection path ordering. # Default view-index array [0, 1, ..., count-1] for "all prims". - self._rebuild_view_to_fabric_index_mapping(force_rebuild=True) + self._view_indices: wp.array = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) + self._fabric_indices: wp.array = self._compute_fabric_indices() # Pre-allocate reusable output buffers (world poses) self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) @@ -248,9 +246,7 @@ def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T positions_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) - self._decompose_transforms( - self._get_world_ro_array(), fabric_indices, positions_wp, orientations_wp, dummy - ) + self._decompose_transforms(self._get_world_ro_array(), fabric_indices, positions_wp, orientations_wp, dummy) if use_cached_buffers: return self._fabric_positions_torch, self._fabric_orientations_torch @@ -270,9 +266,7 @@ def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T translations_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) - self._decompose_transforms( - self._get_local_ro_array(), fabric_indices, translations_wp, orientations_wp, dummy - ) + self._decompose_transforms(self._get_local_ro_array(), fabric_indices, translations_wp, orientations_wp, dummy) if use_cached_buffers: return self._fabric_local_translations_torch, self._fabric_local_orientations_torch @@ -332,7 +326,7 @@ def _compose_transforms( False, indices_wp, ], - device=self._fabric_device, + device=self._device, ) wp.synchronize() @@ -349,22 +343,13 @@ def _decompose_transforms( kernel=fabric_utils.decompose_indexed_fabric_transforms, dim=indices_wp.shape[0], inputs=[matrices, positions_wp, orientations_wp, scales_wp, indices_wp], - device=self._fabric_device, + device=self._device, ) wp.synchronize() - def _rebuild_view_to_fabric_index_mapping(self, force_rebuild: bool = False) -> None: - """Build the view index to fabric index array from PrimSelection path ordering.""" - - # Rebuild indexing only when fabric topology has changed or whenever forced. - topology_changed = self._index_selection.PrepareForReuse() - if not (topology_changed or force_rebuild): - return - - self._view_indices = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) - + def _compute_fabric_indices(self) -> wp.array: # Assign to each prim an index - fabric_paths = self._index_selection.GetPaths() + fabric_paths = self._index_selection_ro.GetPaths() path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} # Look up the index for each prim observed by this view @@ -377,36 +362,53 @@ def _rebuild_view_to_fabric_index_mapping(self, force_rebuild: bool = False) -> ) fabric_indices.append(fabric_idx) - self._fabric_indices = wp.array(fabric_indices, dtype=wp.int32).to(self._fabric_device) + return wp.array(fabric_indices, dtype=wp.int32).to(self._device) + + def _ensure_fabric_indices_are_up_to_date(self, force_rebuild: bool = False) -> None: + """Build the view index to fabric index array from PrimSelection path ordering.""" + + # Rebuild indexing only when fabric topology has changed or whenever forced. + topology_changed = self._index_selection_ro.PrepareForReuse() + if not (topology_changed or force_rebuild): + return + + self._fabric_indices = self._compute_fabric_indices() def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: """Create an indexed fabric array for a single attribute with the given access mode.""" import usdrt - selection = self._fabric_stage.SelectPrims( + + selection = self._stage.SelectPrims( require_attrs=[ (usdrt.Sdf.ValueTypeNames.Matrix4d, attr_name, access), ], - device=self._fabric_device, + device=self._device, ) fa = wp.fabricarray(selection, attr_name) return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) def _get_world_ro_array(self) -> wp.indexedfabricarray: import usdrt + + self._ensure_fabric_indices_are_up_to_date() return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read) def _get_world_rw_array(self) -> wp.indexedfabricarray: import usdrt - self._rebuild_view_to_fabric_index_mapping() + + self._ensure_fabric_indices_are_up_to_date() return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) def _get_local_ro_array(self) -> wp.indexedfabricarray: import usdrt + + self._ensure_fabric_indices_are_up_to_date() return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read) def _get_local_rw_array(self) -> wp.indexedfabricarray: import usdrt - self._rebuild_view_to_fabric_index_mapping() + + self._ensure_fabric_indices_are_up_to_date() return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.array: From 4c2ac4a8147f30d6fa14dc5020bbcddb7ce86399 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Fri, 3 Apr 2026 13:30:03 -0700 Subject: [PATCH 16/20] Update to reuse buffers through PrepareForReuse --- .../benchmarks/benchmark_xform_prim_view.py | 5 +- source/isaaclab/isaaclab/sim/utils/stage.py | 50 ++++++++++-- .../sim/views/xform_fabric_backend.py | 76 +++++++++++-------- .../test/sim/test_views_xform_prim.py | 9 ++- 4 files changed, 98 insertions(+), 42 deletions(-) diff --git a/scripts/benchmarks/benchmark_xform_prim_view.py b/scripts/benchmarks/benchmark_xform_prim_view.py index e76796e20271..5a9b5884fa64 100644 --- a/scripts/benchmarks/benchmark_xform_prim_view.py +++ b/scripts/benchmarks/benchmark_xform_prim_view.py @@ -54,6 +54,7 @@ AppLauncher.add_app_launcher_args(parser) args_cli = parser.parse_args() +args_cli.enable_cameras = True # launch omniverse app app_launcher = AppLauncher(args_cli) @@ -105,7 +106,9 @@ def benchmark_xform_prim_view( # noqa: C901 # Setup scene print(" Setting up scene") # Clear stage - sim_utils.create_new_stage() + + use_fabric: bool = "fabric" in api.lower() + sim_utils.create_new_stage(create_fabric_stage=use_fabric) # Create simulation context start_time = time.perf_counter() sim_cfg = sim_utils.SimulationCfg( diff --git a/source/isaaclab/isaaclab/sim/utils/stage.py b/source/isaaclab/isaaclab/sim/utils/stage.py index eb2b54d57cbd..723667a257d7 100644 --- a/source/isaaclab/isaaclab/sim/utils/stage.py +++ b/source/isaaclab/isaaclab/sim/utils/stage.py @@ -130,28 +130,65 @@ def _modify_path(asset_path: str) -> str: pass -def create_new_stage() -> Usd.Stage: +def create_new_stage(create_fabric_stage: bool = False) -> Usd.Stage: """Create a new in-memory USD stage. - Creates a new stage using pure USD (``Usd.Stage.CreateInMemory()``). + When ``create_fabric_stage`` is False (default), creates a pure USD stage + via ``Usd.Stage.CreateInMemory()``. + + When ``create_fabric_stage`` is True, creates the stage via + ``usdrt.Usd.Stage.CreateInMemory()`` which provides a paired Fabric store + alongside the USD stage. USD notice handling is enabled so that subsequent + prim creation on the USD stage automatically propagates into Fabric. Both + the USD stage and the USDRT stage handle are kept alive in a thread-local + context to prevent garbage collection from releasing the Fabric resources. If Kit is running and Kit extensions need to discover this stage (e.g. PhysX, ``isaacsim.core.prims.Articulation``), call :func:`attach_stage_to_usd_context` after scene setup. + Args: + create_fabric_stage: If True, create the stage through USDRT with a + backing Fabric store and enable USD-to-Fabric notice-driven sync. + Defaults to False. + Returns: - Usd.Stage: The created USD stage. + The created USD stage. Example: >>> import isaaclab.sim as sim_utils >>> >>> sim_utils.create_new_stage() - Usd.Stage.Open(rootLayer=Sdf.Find('anon:0x7fba6c04f840:World7.usd'), - sessionLayer=Sdf.Find('anon:0x7fba6c01c5c0:World7-session.usda'), - pathResolverContext=) + Usd.Stage.Open(rootLayer=Sdf.Find('anon:0x...'), ...) + >>> sim_utils.create_new_stage(create_fabric_stage=True) + Usd.Stage.Open(rootLayer=Sdf.Find('anon:0x...'), ...) """ + + if create_fabric_stage: + import uuid + + import usdrt + + rt_stage = usdrt.Usd.Stage.CreateInMemory(f"World_{uuid.uuid4().hex[:8]}.usda") + stage = UsdUtils.StageCache.Get().Find(Usd.StageCache.Id.FromLongInt(rt_stage.GetStageId())) + + # Storing stage in the context is required, so both stages aren't going to be garbage collected. + _context.stage = stage + _context.rt_stage = rt_stage + + stage_id = rt_stage.GetStageIdAsStageId() + fabric_id = rt_stage.GetFabricId() + srw_id = rt_stage.GetStageReaderWriterId() + + pop = usdrt.population.IUtils() + pop.set_enable_usd_notice_handling(stage_id, fabric_id, True) + pop.populate_from_usd(srw_id, stage_id, usdrt.Sdf.Path("/"), 0) + pop.apply_pending_usd_updates(stage_id, srw_id, 0) + return stage + stage: Usd.Stage = Usd.Stage.CreateInMemory() _context.stage = stage + _context.rt_stage = None UsdUtils.StageCache.Get().Insert(stage) return stage @@ -400,6 +437,7 @@ def close_stage() -> bool: stage_cache = UsdUtils.StageCache.Get() stage_cache.Clear() _context.stage = None + _context.rt_stage = None return True diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index 1fd30b5e1753..c230d499b0a1 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -77,7 +77,7 @@ def __init__( # Resolve the Fabric device string (SelectPrims only supports cuda:0) if self._device.startswith("cuda"): if self._device == "cuda": - logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.") + logger.info("Fabric device is not specified, defaulting to 'cuda:0'.") elif self._device != "cuda:0": logger.debug( "SelectPrims only supports cuda:0. Using cuda:0 even though simulation device is %s.", @@ -96,13 +96,7 @@ def __init__( # Reuse (or create) a hierarchy handle for this stage. if self._stage_id not in FabricBackend._hierarchy_cache: - pop = usdrt.population.IUtils() - pop.populate_from_usd( - self._stage.GetStageReaderWriterId(), - self._stage.GetStageIdAsStageId(), - usdrt.Sdf.Path("/"), - 0, - ) + self._stage.SynchronizeToFabric() hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( self._stage.GetFabricId(), self._stage.GetStageIdAsStageId(), @@ -124,7 +118,7 @@ def __init__( # Persistent selections — one per (attribute x access-mode) combination. # PrepareForReuse() is called before each use to detect topology changes. - self._index_selection_ro = self._stage.SelectPrims( + self._trans_selection_ro = self._stage.SelectPrims( require_attrs=[world_matrix_ro, local_matrix_ro], device=device, want_paths=True, @@ -132,17 +126,17 @@ def __init__( self._world_selection_rw = self._stage.SelectPrims(require_attrs=[world_matrix_rw], device=device) self._local_selection_rw = self._stage.SelectPrims(require_attrs=[local_matrix_rw], device=device) - # Cached indexed fabric arrays (rebuilt when topology changes) - self._world_ifa_ro: wp.indexedfabricarray | None = None - self._local_ifa_ro: wp.indexedfabricarray | None = None - self._world_ifa_rw: wp.indexedfabricarray | None = None - self._local_ifa_rw: wp.indexedfabricarray | None = None - # Build the view → fabric index array from PrimSelection path ordering. # Default view-index array [0, 1, ..., count-1] for "all prims". self._view_indices: wp.array = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) self._fabric_indices: wp.array = self._compute_fabric_indices() + # Cached indexed fabric arrays (rebuilt when topology changes). + self._world_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_selection_ro, self._WORLD_MATRIX_ATTR) + self._local_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_selection_ro, self._LOCAL_MATRIX_ATTR) + self._world_ifa_rw: wp.indexedfabricarray = self._build_array(self._world_selection_rw, self._WORLD_MATRIX_ATTR) + self._local_ifa_rw: wp.indexedfabricarray = self._build_array(self._local_selection_rw, self._LOCAL_MATRIX_ATTR) + # Pre-allocate reusable output buffers (world poses) self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) self._fabric_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) @@ -349,7 +343,7 @@ def _decompose_transforms( def _compute_fabric_indices(self) -> wp.array: # Assign to each prim an index - fabric_paths = self._index_selection_ro.GetPaths() + fabric_paths = self._trans_selection_ro.GetPaths() path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} # Look up the index for each prim observed by this view @@ -364,16 +358,30 @@ def _compute_fabric_indices(self) -> wp.array: return wp.array(fabric_indices, dtype=wp.int32).to(self._device) - def _ensure_fabric_indices_are_up_to_date(self, force_rebuild: bool = False) -> None: + def _ensure_fabric_indices_are_up_to_date(self, force_rebuild: bool = False) -> bool: """Build the view index to fabric index array from PrimSelection path ordering.""" # Rebuild indexing only when fabric topology has changed or whenever forced. - topology_changed = self._index_selection_ro.PrepareForReuse() + # TODO: consider what is it cheaper, store one selection with paths and call PrepareForReuse, + # Or each time call SelectPrims with the same paths and call PrepareForReuse? + + topology_changed = self._trans_selection_ro.PrepareForReuse() + + if topology_changed: + logger.warning("Fabric topology changed! Rebuilding fabric indices!") + if not (topology_changed or force_rebuild): - return + return False self._fabric_indices = self._compute_fabric_indices() + self._world_ifa_ro = self._build_array(self._trans_selection_ro, self._WORLD_MATRIX_ATTR) + self._local_ifa_ro = self._build_array(self._trans_selection_ro, self._LOCAL_MATRIX_ATTR) + self._world_ifa_rw = self._build_array(self._world_selection_rw, self._WORLD_MATRIX_ATTR) + self._local_ifa_rw = self._build_array(self._local_selection_rw, self._LOCAL_MATRIX_ATTR) + + return True + def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: """Create an indexed fabric array for a single attribute with the given access mode.""" import usdrt @@ -387,29 +395,35 @@ def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: fa = wp.fabricarray(selection, attr_name) return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) - def _get_world_ro_array(self) -> wp.indexedfabricarray: - import usdrt + def _build_array(self, selection: usdrt.PrimSelection, attribute_name: str) -> wp.indexedfabricarray: + fa = wp.fabricarray(selection, attribute_name) + return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) + def _get_world_ro_array(self) -> wp.indexedfabricarray: self._ensure_fabric_indices_are_up_to_date() - return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read) + return self._world_ifa_ro + # import usdrt + # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read) def _get_world_rw_array(self) -> wp.indexedfabricarray: - import usdrt - self._ensure_fabric_indices_are_up_to_date() - return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) + self._world_selection_rw.PrepareForReuse() # entire column as dirty + return self._world_ifa_rw + # import usdrt + # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) def _get_local_ro_array(self) -> wp.indexedfabricarray: - import usdrt - self._ensure_fabric_indices_are_up_to_date() - return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read) + return self._local_ifa_ro + # import usdrt + # return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read) def _get_local_rw_array(self) -> wp.indexedfabricarray: - import usdrt - self._ensure_fabric_indices_are_up_to_date() - return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) + self._local_selection_rw.PrepareForReuse() # entire column as dirty + return self._local_ifa_rw + # import usdrt + # return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.array: """Convert requested view indices to fabric indices. diff --git a/source/isaaclab/test/sim/test_views_xform_prim.py b/source/isaaclab/test/sim/test_views_xform_prim.py index 77dc46763d24..58e9271cd8fd 100644 --- a/source/isaaclab/test/sim/test_views_xform_prim.py +++ b/source/isaaclab/test/sim/test_views_xform_prim.py @@ -7,8 +7,9 @@ from isaaclab.app import AppLauncher -# launch omniverse app -simulation_app = AppLauncher(headless=True).app +# In order to test Fabric backend we need to enable cameras. This setting will enable +# Fabric Scene Delegate, allowing us to test Fabric operations, such as hierarchy updates. +simulation_app = AppLauncher(headless=True, enable_cameras=True).app """Rest everything follows.""" @@ -30,7 +31,7 @@ def test_setup_teardown(): """Create a blank new stage for each test.""" # Setup: Create a new stage - sim_utils.create_new_stage() + sim_utils.create_new_stage(create_fabric_stage=True) sim_utils.update_stage() # Yield for the test @@ -236,7 +237,7 @@ def test_xform_prim_view_initialization_empty_pattern(device): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - sim_utils.create_new_stage() + sim_utils.create_new_stage(create_fabric_stage=True) # Create view with pattern that matches nothing view = XformPrimView("/World/NonExistent_.*", device=device) From c9b68de5fa265643376af62ba0086c6a0a5ea322 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Sat, 4 Apr 2026 01:13:52 -0700 Subject: [PATCH 17/20] Update Fabric Backend and add debug methods --- .../sim/views/xform_fabric_backend.py | 160 +++++++++++++----- 1 file changed, 120 insertions(+), 40 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index c230d499b0a1..dace694cceb1 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -118,24 +118,24 @@ def __init__( # Persistent selections — one per (attribute x access-mode) combination. # PrepareForReuse() is called before each use to detect topology changes. - self._trans_selection_ro = self._stage.SelectPrims( - require_attrs=[world_matrix_ro, local_matrix_ro], - device=device, - want_paths=True, - ) - self._world_selection_rw = self._stage.SelectPrims(require_attrs=[world_matrix_rw], device=device) - self._local_selection_rw = self._stage.SelectPrims(require_attrs=[local_matrix_rw], device=device) + ro_ro = (world_matrix_ro, local_matrix_ro) + ro_rw = (world_matrix_ro, local_matrix_rw) + rw_ro = (world_matrix_rw, local_matrix_ro) + + self._trans_sel_ro = self._stage.SelectPrims(require_attrs=ro_ro, device=device, want_paths=True) + self._world_sel_rw = self._stage.SelectPrims(require_attrs=rw_ro, device=device, want_paths=True) + self._local_sel_rw = self._stage.SelectPrims(require_attrs=ro_rw, device=device, want_paths=True) # Build the view → fabric index array from PrimSelection path ordering. # Default view-index array [0, 1, ..., count-1] for "all prims". self._view_indices: wp.array = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) - self._fabric_indices: wp.array = self._compute_fabric_indices() + self._fabric_indices: wp.array = self._compute_fabric_indices(self._trans_sel_ro) # Cached indexed fabric arrays (rebuilt when topology changes). - self._world_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_selection_ro, self._WORLD_MATRIX_ATTR) - self._local_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_selection_ro, self._LOCAL_MATRIX_ATTR) - self._world_ifa_rw: wp.indexedfabricarray = self._build_array(self._world_selection_rw, self._WORLD_MATRIX_ATTR) - self._local_ifa_rw: wp.indexedfabricarray = self._build_array(self._local_selection_rw, self._LOCAL_MATRIX_ATTR) + self._world_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) + self._local_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_ATTR) + self._world_ifa_rw: wp.indexedfabricarray = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_ATTR) + self._local_ifa_rw: wp.indexedfabricarray = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_ATTR) # Pre-allocate reusable output buffers (world poses) self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) @@ -172,6 +172,7 @@ def prim_paths(self) -> list[str]: self._prim_paths_cache = [p.GetPath().pathString for p in self._prims] return self._prim_paths_cache + # ------------------------------------------------------------------ # Setters # ------------------------------------------------------------------ @@ -341,9 +342,9 @@ def _decompose_transforms( ) wp.synchronize() - def _compute_fabric_indices(self) -> wp.array: + def _compute_fabric_indices(self, selection: usdrt.PrimSelection) -> wp.array: # Assign to each prim an index - fabric_paths = self._trans_selection_ro.GetPaths() + fabric_paths = selection.GetPaths() path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} # Look up the index for each prim observed by this view @@ -358,14 +359,14 @@ def _compute_fabric_indices(self) -> wp.array: return wp.array(fabric_indices, dtype=wp.int32).to(self._device) - def _ensure_fabric_indices_are_up_to_date(self, force_rebuild: bool = False) -> bool: + def _ensure_fabric_indices_are_up_to_date(self, selection, force_rebuild: bool = False) -> bool: """Build the view index to fabric index array from PrimSelection path ordering.""" # Rebuild indexing only when fabric topology has changed or whenever forced. # TODO: consider what is it cheaper, store one selection with paths and call PrepareForReuse, # Or each time call SelectPrims with the same paths and call PrepareForReuse? - topology_changed = self._trans_selection_ro.PrepareForReuse() + topology_changed = selection.PrepareForReuse() if topology_changed: logger.warning("Fabric topology changed! Rebuilding fabric indices!") @@ -373,15 +374,13 @@ def _ensure_fabric_indices_are_up_to_date(self, force_rebuild: bool = False) -> if not (topology_changed or force_rebuild): return False - self._fabric_indices = self._compute_fabric_indices() - - self._world_ifa_ro = self._build_array(self._trans_selection_ro, self._WORLD_MATRIX_ATTR) - self._local_ifa_ro = self._build_array(self._trans_selection_ro, self._LOCAL_MATRIX_ATTR) - self._world_ifa_rw = self._build_array(self._world_selection_rw, self._WORLD_MATRIX_ATTR) - self._local_ifa_rw = self._build_array(self._local_selection_rw, self._LOCAL_MATRIX_ATTR) - + self._fabric_indices = self._compute_fabric_indices(selection) return True + def _build_array(self, selection: usdrt.PrimSelection, attribute_name: str) -> wp.indexedfabricarray: + fa = wp.fabricarray(selection, attribute_name) + return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) + def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: """Create an indexed fabric array for a single attribute with the given access mode.""" import usdrt @@ -391,39 +390,45 @@ def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: (usdrt.Sdf.ValueTypeNames.Matrix4d, attr_name, access), ], device=self._device, + want_paths=True, ) fa = wp.fabricarray(selection, attr_name) return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) - def _build_array(self, selection: usdrt.PrimSelection, attribute_name: str) -> wp.indexedfabricarray: - fa = wp.fabricarray(selection, attribute_name) - return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) - def _get_world_ro_array(self) -> wp.indexedfabricarray: - self._ensure_fabric_indices_are_up_to_date() - return self._world_ifa_ro # import usdrt # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read) - def _get_world_rw_array(self) -> wp.indexedfabricarray: - self._ensure_fabric_indices_are_up_to_date() - self._world_selection_rw.PrepareForReuse() # entire column as dirty - return self._world_ifa_rw - # import usdrt - # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) + if self._trans_sel_ro.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) + self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) + self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_ATTR) + return self._world_ifa_ro def _get_local_ro_array(self) -> wp.indexedfabricarray: - self._ensure_fabric_indices_are_up_to_date() - return self._local_ifa_ro # import usdrt # return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read) + if self._local_sel_rw.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) + self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_ATTR) + return self._local_ifa_rw + + def _get_world_rw_array(self) -> wp.indexedfabricarray: + # import usdrt + # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) + if self._world_sel_rw.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._world_sel_rw) + self._world_ifa_rw = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_ATTR) + return self._world_ifa_rw def _get_local_rw_array(self) -> wp.indexedfabricarray: - self._ensure_fabric_indices_are_up_to_date() - self._local_selection_rw.PrepareForReuse() # entire column as dirty - return self._local_ifa_rw # import usdrt # return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) + if self._local_sel_rw.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + self._local_ifa_rw = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_ATTR) + return self._local_ifa_rw def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.array: """Convert requested view indices to fabric indices. @@ -440,3 +445,78 @@ def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.a return self._view_indices indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) return wp.array(indices_list, dtype=wp.uint32).to(self._device) + + # ------------------------------------------------------------------ + # Debug helpers + # ------------------------------------------------------------------ + + def debug_read_fabric_matrices(self, indices: Sequence[int] | None = None) -> dict[str, list]: + """Read world and local matrices directly from Fabric via USDRT for debugging. + + Bypasses the Warp kernel path entirely and reads raw ``Gf.Matrix4d`` + values per-prim through the USDRT prim API. Useful for verifying that + Fabric contains the expected data after writes or population. + + Args: + indices: Prim indices to read. Defaults to all prims. + + Returns: + Dictionary with keys ``"prim_path"``, ``"world_matrix"``, and + ``"local_matrix"``, each a list with one entry per queried prim. + """ + import usdrt + + if indices is None: + indices = list(range(self.count)) + + result: dict[str, list] = {"prim_path": [], "world_matrix": [], "local_matrix": []} + for idx in indices: + prim_path = self.prim_paths[idx] + rt_prim = self._stage.GetPrimAtPath(usdrt.Sdf.Path(prim_path)) + + world_mat = None + local_mat = None + if rt_prim.IsValid(): + if rt_prim.HasAttribute(self._WORLD_MATRIX_ATTR): + world_mat = rt_prim.GetAttribute(self._WORLD_MATRIX_ATTR).Get() + if rt_prim.HasAttribute(self._LOCAL_MATRIX_ATTR): + local_mat = rt_prim.GetAttribute(self._LOCAL_MATRIX_ATTR).Get() + + result["prim_path"].append(prim_path) + result["world_matrix"].append(world_mat) + result["local_matrix"].append(local_mat) + + return result + + def debug_print_fabric_state(self, indices: Sequence[int] | None = None) -> None: + """Print Fabric matrix state, index mapping, and selection paths to stdout. + + Args: + indices: Prim indices to print. Defaults to all prims. + """ + if indices is None: + indices = list(range(self.count)) + + fabric_indices_np = self._fabric_indices.numpy() + fabric_paths = self._trans_sel_ro.GetPaths() + + print(f"[Fabric Debug] stage_id={self._stage_id} device={self._device} count={self.count}") + print(f"[Fabric Debug] SelectPrims returned {len(fabric_paths)} paths:") + for fi, fp in enumerate(fabric_paths): + print(f" fabric_idx={fi} path={fp}") + + print("[Fabric Debug] View → Fabric index mapping:") + for vi in range(self.count): + fi = int(fabric_indices_np[vi]) + print(f" view_idx={vi} fabric_idx={fi} path={self.prim_paths[vi]}") + + data = self.debug_read_fabric_matrices(indices) + print(f"[Fabric Debug] Matrices for requested indices {list(indices)}:") + for i, path in enumerate(data["prim_path"]): + vi = indices[i] + fi = int(fabric_indices_np[vi]) + wm = data["world_matrix"][i] + lm = data["local_matrix"][i] + print(f" [{vi}] {path} (fabric_idx={fi})") + print(f" world: {wm}") + print(f" local: {lm}") From ff5da88ebb6af8602ecc32543e7951245d8a6af8 Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Sat, 4 Apr 2026 09:45:02 -0700 Subject: [PATCH 18/20] Fix a bug with ro/wr accessor --- source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index dace694cceb1..df2e283bc26f 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -398,7 +398,6 @@ def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: def _get_world_ro_array(self) -> wp.indexedfabricarray: # import usdrt # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read) - if self._trans_sel_ro.PrepareForReuse(): self._fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) @@ -408,11 +407,11 @@ def _get_world_ro_array(self) -> wp.indexedfabricarray: def _get_local_ro_array(self) -> wp.indexedfabricarray: # import usdrt # return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read) - if self._local_sel_rw.PrepareForReuse(): - self._fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + if self._trans_sel_ro.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_ATTR) - return self._local_ifa_rw + return self._local_ifa_ro def _get_world_rw_array(self) -> wp.indexedfabricarray: # import usdrt From 8e704a2b9d2b7a64740fbd7496a7e5b2de92aace Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Sat, 4 Apr 2026 10:21:41 -0700 Subject: [PATCH 19/20] update repr --- .../sim/views/xform_fabric_backend.py | 43 +++++++++++-------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index df2e283bc26f..9c1ed4f537df 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -487,35 +487,42 @@ def debug_read_fabric_matrices(self, indices: Sequence[int] | None = None) -> di return result - def debug_print_fabric_state(self, indices: Sequence[int] | None = None) -> None: - """Print Fabric matrix state, index mapping, and selection paths to stdout. - - Args: - indices: Prim indices to print. Defaults to all prims. - """ - if indices is None: - indices = list(range(self.count)) + def __repr__(self) -> str: + lines: list[str] = [] + indices = list(range(self.count)) fabric_indices_np = self._fabric_indices.numpy() fabric_paths = self._trans_sel_ro.GetPaths() - print(f"[Fabric Debug] stage_id={self._stage_id} device={self._device} count={self.count}") - print(f"[Fabric Debug] SelectPrims returned {len(fabric_paths)} paths:") - for fi, fp in enumerate(fabric_paths): - print(f" fabric_idx={fi} path={fp}") + view_paths = self.prim_paths + + lines.append(f"[FabricBackend] stage_id={self._stage_id} device={self._device} count={self.count}") + import usdrt - print("[Fabric Debug] View → Fabric index mapping:") + lines.append(f"SelectPrims returned {len(fabric_paths)} paths:") + for fi, fp in enumerate(fabric_paths): + marker = " *" if fp in view_paths else " " + rt_prim = self._stage.GetPrimAtPath(usdrt.Sdf.Path(str(fp))) + wm = rt_prim.GetAttribute(self._WORLD_MATRIX_ATTR).Get() if rt_prim.HasAttribute(self._WORLD_MATRIX_ATTR) else None + lm = rt_prim.GetAttribute(self._LOCAL_MATRIX_ATTR).Get() if rt_prim.HasAttribute(self._LOCAL_MATRIX_ATTR) else None + lines.append(f"{marker} fabric_idx={fi} path={fp}") + lines.append(f" world: {wm}") + lines.append(f" local: {lm}") + + lines.append("View → Fabric index mapping:") for vi in range(self.count): fi = int(fabric_indices_np[vi]) - print(f" view_idx={vi} fabric_idx={fi} path={self.prim_paths[vi]}") + lines.append(f" view_idx={vi} fabric_idx={fi} path={self.prim_paths[vi]}") data = self.debug_read_fabric_matrices(indices) - print(f"[Fabric Debug] Matrices for requested indices {list(indices)}:") + lines.append("Matrices:") for i, path in enumerate(data["prim_path"]): vi = indices[i] fi = int(fabric_indices_np[vi]) wm = data["world_matrix"][i] lm = data["local_matrix"][i] - print(f" [{vi}] {path} (fabric_idx={fi})") - print(f" world: {wm}") - print(f" local: {lm}") + lines.append(f" [{vi}] {path} (fabric_idx={fi})") + lines.append(f" world: {wm}") + lines.append(f" local: {lm}") + + return "\n".join(lines) From a70c115880500976f67f4c6e218a12c0f5e30c0d Mon Sep 17 00:00:00 2001 From: Piotr Barejko Date: Sat, 4 Apr 2026 10:34:14 -0700 Subject: [PATCH 20/20] Rename and reformat --- .../sim/views/xform_fabric_backend.py | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py index 9c1ed4f537df..c1fb9dfc33c5 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -60,8 +60,8 @@ class FabricBackend: If topology of the fabric changes, then all fabric indices need to be rebuilt. """ - _WORLD_MATRIX_ATTR = "omni:fabric:worldMatrix" - _LOCAL_MATRIX_ATTR = "omni:fabric:localMatrix" + _WORLD_MATRIX_NAME = "omni:fabric:worldMatrix" + _LOCAL_MATRIX_NAME = "omni:fabric:localMatrix" _hierarchy_cache: dict[int, object] = {} _dirty_stages: set[int] = set() @@ -111,10 +111,10 @@ def __init__( matrix = usdrt.Sdf.ValueTypeNames.Matrix4d ro = usdrt.Usd.Access.Read rw = usdrt.Usd.Access.ReadWrite - world_matrix_ro = (matrix, self._WORLD_MATRIX_ATTR, ro) - local_matrix_ro = (matrix, self._LOCAL_MATRIX_ATTR, ro) - world_matrix_rw = (matrix, self._WORLD_MATRIX_ATTR, rw) - local_matrix_rw = (matrix, self._LOCAL_MATRIX_ATTR, rw) + world_matrix_ro = (matrix, self._WORLD_MATRIX_NAME, ro) + local_matrix_ro = (matrix, self._LOCAL_MATRIX_NAME, ro) + world_matrix_rw = (matrix, self._WORLD_MATRIX_NAME, rw) + local_matrix_rw = (matrix, self._LOCAL_MATRIX_NAME, rw) # Persistent selections — one per (attribute x access-mode) combination. # PrepareForReuse() is called before each use to detect topology changes. @@ -132,10 +132,10 @@ def __init__( self._fabric_indices: wp.array = self._compute_fabric_indices(self._trans_sel_ro) # Cached indexed fabric arrays (rebuilt when topology changes). - self._world_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) - self._local_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_ATTR) - self._world_ifa_rw: wp.indexedfabricarray = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_ATTR) - self._local_ifa_rw: wp.indexedfabricarray = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_ATTR) + self._world_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_NAME) + self._local_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_NAME) + self._world_ifa_rw: wp.indexedfabricarray = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_NAME) + self._local_ifa_rw: wp.indexedfabricarray = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_NAME) # Pre-allocate reusable output buffers (world poses) self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) @@ -172,7 +172,6 @@ def prim_paths(self) -> list[str]: self._prim_paths_cache = [p.GetPath().pathString for p in self._prims] return self._prim_paths_cache - # ------------------------------------------------------------------ # Setters # ------------------------------------------------------------------ @@ -397,36 +396,36 @@ def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: def _get_world_ro_array(self) -> wp.indexedfabricarray: # import usdrt - # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.Read) + # return self._select_indexed(self._WORLD_MATRIX_NAME, usdrt.Usd.Access.Read) if self._trans_sel_ro.PrepareForReuse(): self._fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) - self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) - self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_ATTR) + self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_NAME) + self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_NAME) return self._world_ifa_ro def _get_local_ro_array(self) -> wp.indexedfabricarray: # import usdrt - # return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.Read) + # return self._select_indexed(self._LOCAL_MATRIX_NAME, usdrt.Usd.Access.Read) if self._trans_sel_ro.PrepareForReuse(): self._fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) - self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_ATTR) - self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_ATTR) + self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_NAME) + self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_NAME) return self._local_ifa_ro def _get_world_rw_array(self) -> wp.indexedfabricarray: # import usdrt - # return self._select_indexed(self._WORLD_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) + # return self._select_indexed(self._WORLD_MATRIX_NAME, usdrt.Usd.Access.ReadWrite) if self._world_sel_rw.PrepareForReuse(): self._fabric_indices = self._compute_fabric_indices(self._world_sel_rw) - self._world_ifa_rw = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_ATTR) + self._world_ifa_rw = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_NAME) return self._world_ifa_rw def _get_local_rw_array(self) -> wp.indexedfabricarray: # import usdrt - # return self._select_indexed(self._LOCAL_MATRIX_ATTR, usdrt.Usd.Access.ReadWrite) + # return self._select_indexed(self._LOCAL_MATRIX_NAME, usdrt.Usd.Access.ReadWrite) if self._local_sel_rw.PrepareForReuse(): self._fabric_indices = self._compute_fabric_indices(self._local_sel_rw) - self._local_ifa_rw = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_ATTR) + self._local_ifa_rw = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_NAME) return self._local_ifa_rw def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.array: @@ -476,10 +475,10 @@ def debug_read_fabric_matrices(self, indices: Sequence[int] | None = None) -> di world_mat = None local_mat = None if rt_prim.IsValid(): - if rt_prim.HasAttribute(self._WORLD_MATRIX_ATTR): - world_mat = rt_prim.GetAttribute(self._WORLD_MATRIX_ATTR).Get() - if rt_prim.HasAttribute(self._LOCAL_MATRIX_ATTR): - local_mat = rt_prim.GetAttribute(self._LOCAL_MATRIX_ATTR).Get() + if rt_prim.HasAttribute(self._WORLD_MATRIX_NAME): + world_mat = rt_prim.GetAttribute(self._WORLD_MATRIX_NAME).Get() + if rt_prim.HasAttribute(self._LOCAL_MATRIX_NAME): + local_mat = rt_prim.GetAttribute(self._LOCAL_MATRIX_NAME).Get() result["prim_path"].append(prim_path) result["world_matrix"].append(world_mat) @@ -503,8 +502,16 @@ def __repr__(self) -> str: for fi, fp in enumerate(fabric_paths): marker = " *" if fp in view_paths else " " rt_prim = self._stage.GetPrimAtPath(usdrt.Sdf.Path(str(fp))) - wm = rt_prim.GetAttribute(self._WORLD_MATRIX_ATTR).Get() if rt_prim.HasAttribute(self._WORLD_MATRIX_ATTR) else None - lm = rt_prim.GetAttribute(self._LOCAL_MATRIX_ATTR).Get() if rt_prim.HasAttribute(self._LOCAL_MATRIX_ATTR) else None + wm = ( + rt_prim.GetAttribute(self._WORLD_MATRIX_NAME).Get() + if rt_prim.HasAttribute(self._WORLD_MATRIX_NAME) + else None + ) + lm = ( + rt_prim.GetAttribute(self._LOCAL_MATRIX_NAME).Get() + if rt_prim.HasAttribute(self._LOCAL_MATRIX_NAME) + else None + ) lines.append(f"{marker} fabric_idx={fi} path={fp}") lines.append(f" world: {wm}") lines.append(f" local: {lm}")