diff --git a/scripts/benchmarks/benchmark_xform_prim_view.py b/scripts/benchmarks/benchmark_xform_prim_view.py index e76796e20271..5a9b5884fa64 100644 --- a/scripts/benchmarks/benchmark_xform_prim_view.py +++ b/scripts/benchmarks/benchmark_xform_prim_view.py @@ -54,6 +54,7 @@ AppLauncher.add_app_launcher_args(parser) args_cli = parser.parse_args() +args_cli.enable_cameras = True # launch omniverse app app_launcher = AppLauncher(args_cli) @@ -105,7 +106,9 @@ def benchmark_xform_prim_view( # noqa: C901 # Setup scene print(" Setting up scene") # Clear stage - sim_utils.create_new_stage() + + use_fabric: bool = "fabric" in api.lower() + sim_utils.create_new_stage(create_fabric_stage=use_fabric) # Create simulation context start_time = time.perf_counter() sim_cfg = sim_utils.SimulationCfg( diff --git a/source/isaaclab/isaaclab/sensors/camera/camera.py b/source/isaaclab/isaaclab/sensors/camera/camera.py index 70ccc6c14dd8..f03c399500f2 100644 --- a/source/isaaclab/isaaclab/sensors/camera/camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/camera.py @@ -450,7 +450,9 @@ def _initialize_impl(self): super()._initialize_impl() # Create a view for the sensor with Fabric enabled for fast pose queries, otherwise position will be stale. self._view = XformPrimView( - self.cfg.prim_path, device=self._device, stage=self.stage, sync_usd_on_fabric_write=True + self.cfg.prim_path, + device=self._device, + stage=self.stage, ) # Check that sizes are correct if self._view.count != self._num_envs: diff --git a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py index e6ae56adaa53..4e72cfa615bf 100644 --- a/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/tiled_camera.py @@ -170,8 +170,11 @@ def _initialize_impl(self): # the view keeps references to the prims located in the stage self.renderer.prepare_stage(self.stage, self._num_envs) - # Create a view for the sensor - self._view = XformPrimView(self.cfg.prim_path, device=self._device, stage=self.stage) + self._view = XformPrimView( + self.cfg.prim_path, + device=self._device, + stage=self.stage, + ) # Check that sizes are correct if self._view.count != self._num_envs: raise RuntimeError( diff --git a/source/isaaclab/isaaclab/sim/utils/stage.py b/source/isaaclab/isaaclab/sim/utils/stage.py index eb2b54d57cbd..723667a257d7 100644 --- a/source/isaaclab/isaaclab/sim/utils/stage.py +++ b/source/isaaclab/isaaclab/sim/utils/stage.py @@ -130,28 +130,65 @@ def _modify_path(asset_path: str) -> str: pass -def create_new_stage() -> Usd.Stage: +def create_new_stage(create_fabric_stage: bool = False) -> Usd.Stage: """Create a new in-memory USD stage. - Creates a new stage using pure USD (``Usd.Stage.CreateInMemory()``). + When ``create_fabric_stage`` is False (default), creates a pure USD stage + via ``Usd.Stage.CreateInMemory()``. + + When ``create_fabric_stage`` is True, creates the stage via + ``usdrt.Usd.Stage.CreateInMemory()`` which provides a paired Fabric store + alongside the USD stage. USD notice handling is enabled so that subsequent + prim creation on the USD stage automatically propagates into Fabric. Both + the USD stage and the USDRT stage handle are kept alive in a thread-local + context to prevent garbage collection from releasing the Fabric resources. If Kit is running and Kit extensions need to discover this stage (e.g. PhysX, ``isaacsim.core.prims.Articulation``), call :func:`attach_stage_to_usd_context` after scene setup. + Args: + create_fabric_stage: If True, create the stage through USDRT with a + backing Fabric store and enable USD-to-Fabric notice-driven sync. + Defaults to False. + Returns: - Usd.Stage: The created USD stage. + The created USD stage. Example: >>> import isaaclab.sim as sim_utils >>> >>> sim_utils.create_new_stage() - Usd.Stage.Open(rootLayer=Sdf.Find('anon:0x7fba6c04f840:World7.usd'), - sessionLayer=Sdf.Find('anon:0x7fba6c01c5c0:World7-session.usda'), - pathResolverContext=) + Usd.Stage.Open(rootLayer=Sdf.Find('anon:0x...'), ...) + >>> sim_utils.create_new_stage(create_fabric_stage=True) + Usd.Stage.Open(rootLayer=Sdf.Find('anon:0x...'), ...) """ + + if create_fabric_stage: + import uuid + + import usdrt + + rt_stage = usdrt.Usd.Stage.CreateInMemory(f"World_{uuid.uuid4().hex[:8]}.usda") + stage = UsdUtils.StageCache.Get().Find(Usd.StageCache.Id.FromLongInt(rt_stage.GetStageId())) + + # Storing stage in the context is required, so both stages aren't going to be garbage collected. + _context.stage = stage + _context.rt_stage = rt_stage + + stage_id = rt_stage.GetStageIdAsStageId() + fabric_id = rt_stage.GetFabricId() + srw_id = rt_stage.GetStageReaderWriterId() + + pop = usdrt.population.IUtils() + pop.set_enable_usd_notice_handling(stage_id, fabric_id, True) + pop.populate_from_usd(srw_id, stage_id, usdrt.Sdf.Path("/"), 0) + pop.apply_pending_usd_updates(stage_id, srw_id, 0) + return stage + stage: Usd.Stage = Usd.Stage.CreateInMemory() _context.stage = stage + _context.rt_stage = None UsdUtils.StageCache.Get().Insert(stage) return stage @@ -400,6 +437,7 @@ def close_stage() -> bool: stage_cache = UsdUtils.StageCache.Get() stage_cache.Clear() _context.stage = None + _context.rt_stage = None return True diff --git a/source/isaaclab/isaaclab/sim/views/__init__.pyi b/source/isaaclab/isaaclab/sim/views/__init__.pyi index a666958e4387..76864b225a6e 100644 --- a/source/isaaclab/isaaclab/sim/views/__init__.pyi +++ b/source/isaaclab/isaaclab/sim/views/__init__.pyi @@ -4,7 +4,13 @@ # SPDX-License-Identifier: BSD-3-Clause __all__ = [ + "FabricBackend", + "UsdBackend", + "XformBackend", "XformPrimView", ] +from .xform_backend import XformBackend +from .xform_fabric_backend import FabricBackend from .xform_prim_view import XformPrimView +from .xform_usd_backend import UsdBackend diff --git a/source/isaaclab/isaaclab/sim/views/xform_backend.py b/source/isaaclab/isaaclab/sim/views/xform_backend.py new file mode 100644 index 000000000000..e173d3764647 --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_backend.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Protocol + +import torch + + +class XformBackend(Protocol): + """Protocol defining the interface for :class:`XformPrimView` transform backends. + + Implementations provide read/write access to prim transforms through either + the USD or Fabric data path. :class:`XformPrimView` delegates all transform + operations to a *primary* backend and optionally replicates writes to one or + more *sync* backends. + """ + + def initialize(self) -> None: + """Perform any deferred initialisation required by the backend.""" + ... + + def set_world_poses( + self, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set world-space poses for the managed prims. + + Args: + positions: World-space positions, shape ``(M, 3)`` [m]. + orientations: World-space quaternions ``(x, y, z, w)``, shape ``(M, 4)``. + indices: Subset of prim indices to update. ``None`` means all. + """ + ... + + def get_world_poses( + self, + indices: Sequence[int] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Return world-space ``(positions, orientations)`` for the managed prims. + + Args: + indices: Subset of prim indices to query. ``None`` means all. + + Returns: + ``(positions, orientations)`` with shapes ``(M, 3)`` and ``(M, 4)``. + """ + ... + + def set_local_poses( + self, + translations: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set local-space poses (relative to each prim's parent). + + Args: + translations: Local-space translations, shape ``(M, 3)`` [m]. + orientations: Local-space quaternions ``(x, y, z, w)``, shape ``(M, 4)``. + indices: Subset of prim indices to update. ``None`` means all. + """ + ... + + def get_local_poses( + self, + indices: Sequence[int] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Return local-space ``(translations, orientations)`` for the managed prims. + + Args: + indices: Subset of prim indices to query. ``None`` means all. + + Returns: + ``(translations, orientations)`` with shapes ``(M, 3)`` and ``(M, 4)``. + """ + ... + + def set_scales( + self, + scales: torch.Tensor, + indices: Sequence[int] | None = None, + ) -> None: + """Set scales for the managed prims. + + Args: + scales: Scales, shape ``(M, 3)``. + indices: Subset of prim indices to update. ``None`` means all. + """ + ... + + def get_scales( + self, + indices: Sequence[int] | None = None, + ) -> torch.Tensor: + """Return scales for the managed prims. + + Args: + indices: Subset of prim indices to query. ``None`` means all. + + Returns: + Scales tensor of shape ``(M, 3)``. + """ + ... diff --git a/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py new file mode 100644 index 000000000000..c1fb9dfc33c5 --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_fabric_backend.py @@ -0,0 +1,535 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import logging +from collections.abc import Sequence + +import torch +import warp as wp + +from pxr import Usd + +import isaaclab.sim as sim_utils +from isaaclab.utils.warp import fabric as fabric_utils + +logger = logging.getLogger(__name__) + + +class FabricBackend: + """Fabric-based transform backend for :class:`XformPrimView`. + + Uses NVIDIA's Fabric API with Warp GPU kernels for high-performance batch + transform operations. + + Selected primitives based on attributes such as local and world matrices. + (all prims in selection, could be in different buckets): + ┌─────────┬───────────────┬───────────┐ + │ Fab Idx │ Prim Path │ attribute │ + ├─────────┼───────────────┼───────────┤ + │ 0 │ /World/Light │ [ ... ] │ + │ 1 │ /World/Cam_2 │ [ ... ] │ + │ 2 │ /World/Ground │ [ ... ] │ + │ 3 │ /World/Cam_0 │ [ ... ] │ + │ 4 │ /World/Table │ [ ... ] │ + │ 5 │ /World/Cam_1 │ [ ... ] │ + │ 6 │ /World/Robot │ [ ... ] │ + └─────────┴───────────────┴───────────┘ + + Example view of 3 prims, order of the paths defines order of indices + ┌──────────┬──────────────┐ + │ View Idx │ Prim Path │ + ├──────────┼──────────────┤ + │ 0 │ /World/Cam_0 │ + │ 1 │ /World/Cam_1 │ + │ 2 │ /World/Cam_2 │ + └──────────┴──────────────┘ + + Mapping from view indices to fabric indices happens through a fabric index array: + ┌──────────┬─────────┐ + │ View Idx │ Fab Idx │ + ├──────────┼─────────┤ + │ 0 │ 3 │ + │ 1 │ 5 │ + │ 2 │ 1 │ + └──────────┴─────────┘ + + If topology of the fabric changes, then all fabric indices need to be rebuilt. + """ + + _WORLD_MATRIX_NAME = "omni:fabric:worldMatrix" + _LOCAL_MATRIX_NAME = "omni:fabric:localMatrix" + + _hierarchy_cache: dict[int, object] = {} + _dirty_stages: set[int] = set() + + def __init__( + self, + prims: list[Usd.Prim], + device: str, + ): + self._prims = prims + self._device = device + + # Resolve the Fabric device string (SelectPrims only supports cuda:0) + if self._device.startswith("cuda"): + if self._device == "cuda": + logger.info("Fabric device is not specified, defaulting to 'cuda:0'.") + elif self._device != "cuda:0": + logger.debug( + "SelectPrims only supports cuda:0. Using cuda:0 even though simulation device is %s.", + self._device, + ) + device = "cuda:0" + else: + device = self._device + self._device = device + + import usdrt + from usdrt import Rt # noqa: F401 — imported for side-effects + + self._stage_id = sim_utils.get_current_stage_id() + self._stage = usdrt.Usd.Stage.Attach(self._stage_id) + + # Reuse (or create) a hierarchy handle for this stage. + if self._stage_id not in FabricBackend._hierarchy_cache: + self._stage.SynchronizeToFabric() + hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( + self._stage.GetFabricId(), + self._stage.GetStageIdAsStageId(), + ) + hierarchy.update_world_xforms() + hierarchy.track_local_xform_changes(True) + hierarchy.track_world_xform_changes(True) + FabricBackend._hierarchy_cache[self._stage_id] = hierarchy + + self._fabric_hierarchy = FabricBackend._hierarchy_cache[self._stage_id] + + matrix = usdrt.Sdf.ValueTypeNames.Matrix4d + ro = usdrt.Usd.Access.Read + rw = usdrt.Usd.Access.ReadWrite + world_matrix_ro = (matrix, self._WORLD_MATRIX_NAME, ro) + local_matrix_ro = (matrix, self._LOCAL_MATRIX_NAME, ro) + world_matrix_rw = (matrix, self._WORLD_MATRIX_NAME, rw) + local_matrix_rw = (matrix, self._LOCAL_MATRIX_NAME, rw) + + # Persistent selections — one per (attribute x access-mode) combination. + # PrepareForReuse() is called before each use to detect topology changes. + ro_ro = (world_matrix_ro, local_matrix_ro) + ro_rw = (world_matrix_ro, local_matrix_rw) + rw_ro = (world_matrix_rw, local_matrix_ro) + + self._trans_sel_ro = self._stage.SelectPrims(require_attrs=ro_ro, device=device, want_paths=True) + self._world_sel_rw = self._stage.SelectPrims(require_attrs=rw_ro, device=device, want_paths=True) + self._local_sel_rw = self._stage.SelectPrims(require_attrs=ro_rw, device=device, want_paths=True) + + # Build the view → fabric index array from PrimSelection path ordering. + # Default view-index array [0, 1, ..., count-1] for "all prims". + self._view_indices: wp.array = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) + self._fabric_indices: wp.array = self._compute_fabric_indices(self._trans_sel_ro) + + # Cached indexed fabric arrays (rebuilt when topology changes). + self._world_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_NAME) + self._local_ifa_ro: wp.indexedfabricarray = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_NAME) + self._world_ifa_rw: wp.indexedfabricarray = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_NAME) + self._local_ifa_rw: wp.indexedfabricarray = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_NAME) + + # Pre-allocate reusable output buffers (world poses) + self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) + self._fabric_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) + self._fabric_scales_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) + + self._fabric_positions_buffer = wp.from_torch(self._fabric_positions_torch, dtype=wp.float32) + self._fabric_orientations_buffer = wp.from_torch(self._fabric_orientations_torch, dtype=wp.float32) + self._fabric_scales_buffer = wp.from_torch(self._fabric_scales_torch, dtype=wp.float32) + + # Pre-allocate reusable output buffers (local poses) + self._fabric_local_translations_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) + self._fabric_local_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) + + self._fabric_local_translations_buffer = wp.from_torch(self._fabric_local_translations_torch, dtype=wp.float32) + self._fabric_local_orientations_buffer = wp.from_torch(self._fabric_local_orientations_torch, dtype=wp.float32) + + # Dummy buffer for unused kernel outputs (always empty) + self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def count(self) -> int: + """Number of prims managed by this backend.""" + return len(self._prims) + + @property + def prim_paths(self) -> list[str]: + """Prim path strings (lazily cached).""" + if not hasattr(self, "_prim_paths_cache"): + self._prim_paths_cache = [p.GetPath().pathString for p in self._prims] + return self._prim_paths_cache + + # ------------------------------------------------------------------ + # Setters + # ------------------------------------------------------------------ + + def set_world_poses( + self, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Write world poses to Fabric ``omni:fabric:worldMatrix`` via a Warp kernel.""" + + # if local transforms were set, we need to update the world transforms + if self._stage_id in FabricBackend._dirty_stages: + self._fabric_hierarchy.update_world_xforms() + FabricBackend._dirty_stages.discard(self._stage_id) + + fabric_indices = self._convert_view_to_fabric_indices(indices) + self._compose_transforms( + self._get_world_rw_array(), fabric_indices, positions=positions, orientations=orientations + ) + + def set_local_poses( + self, + translations: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Write local poses to Fabric ``omni:fabric:localMatrix`` via a Warp kernel. + + After composing the local matrix the method re-registers it through + :pyobj:`IFabricHierarchy` and marks world transforms as dirty so that a + subsequent read will propagate the change. + """ + fabric_indices = self._convert_view_to_fabric_indices(indices) + self._compose_transforms( + self._get_local_rw_array(), fabric_indices, positions=translations, orientations=orientations + ) + + FabricBackend._dirty_stages.add(self._stage_id) + + def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) -> None: + """Write scales into the Fabric world matrix via a Warp kernel.""" + fabric_indices = self._convert_view_to_fabric_indices(indices) + self._compose_transforms(self._get_world_rw_array(), fabric_indices, scales=scales) + + # ------------------------------------------------------------------ + # Getters + # ------------------------------------------------------------------ + + def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Read world poses from Fabric and decompose via a Warp kernel.""" + if self._stage_id in FabricBackend._dirty_stages: + self._fabric_hierarchy.update_world_xforms() + FabricBackend._dirty_stages.discard(self._stage_id) + + fabric_indices = self._convert_view_to_fabric_indices(indices) + count = fabric_indices.shape[0] + dummy = self._fabric_dummy_buffer + + use_cached_buffers = indices is None or indices == slice(None) + if use_cached_buffers: + positions_wp = self._fabric_positions_buffer + orientations_wp = self._fabric_orientations_buffer + else: + positions_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) + orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) + + self._decompose_transforms(self._get_world_ro_array(), fabric_indices, positions_wp, orientations_wp, dummy) + + if use_cached_buffers: + return self._fabric_positions_torch, self._fabric_orientations_torch + return wp.to_torch(positions_wp), wp.to_torch(orientations_wp) + + def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Read local poses from Fabric and decompose via a Warp kernel.""" + fabric_indices = self._convert_view_to_fabric_indices(indices) + count = fabric_indices.shape[0] + dummy = self._fabric_dummy_buffer + + use_cached_buffers = indices is None or indices == slice(None) + if use_cached_buffers: + translations_wp = self._fabric_local_translations_buffer + orientations_wp = self._fabric_local_orientations_buffer + else: + translations_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) + orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) + + self._decompose_transforms(self._get_local_ro_array(), fabric_indices, translations_wp, orientations_wp, dummy) + + if use_cached_buffers: + return self._fabric_local_translations_torch, self._fabric_local_orientations_torch + return wp.to_torch(translations_wp), wp.to_torch(orientations_wp) + + def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: + """Read scales from Fabric world matrices and extract via a Warp kernel.""" + fabric_indices = self._convert_view_to_fabric_indices(indices) + count = fabric_indices.shape[0] + dummy = self._fabric_dummy_buffer + + use_cached_buffers = indices is None or indices == slice(None) + if use_cached_buffers: + scales_wp = self._fabric_scales_buffer + else: + scales_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) + + self._decompose_transforms(self._get_world_ro_array(), fabric_indices, dummy, dummy, scales_wp) + + if use_cached_buffers: + return self._fabric_scales_torch + return wp.to_torch(scales_wp) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _compose_transforms( + self, + matrices: wp.indexedfabricarray, + indices_wp: wp.array, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + scales: torch.Tensor | None = None, + ) -> None: + """Launch the compose kernel to write transform components into Fabric matrices. + + Converts non-``None`` torch tensors to Warp arrays and substitutes a + pre-allocated zero-length dummy for omitted components so the kernel + leaves existing values untouched. + """ + dummy = self._fabric_dummy_buffer + positions_wp = wp.from_torch(positions) if positions is not None else dummy + orientations_wp = wp.from_torch(orientations) if orientations is not None else dummy + scales_wp = wp.from_torch(scales) if scales is not None else dummy + + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[ + matrices, + positions_wp, + orientations_wp, + scales_wp, + False, + False, + False, + indices_wp, + ], + device=self._device, + ) + wp.synchronize() + + def _decompose_transforms( + self, + matrices: wp.indexedfabricarray, + indices_wp: wp.array, + positions_wp: wp.array, + orientations_wp: wp.array, + scales_wp: wp.array, + ) -> None: + """Launch the decompose kernel to read transform components from Fabric matrices.""" + wp.launch( + kernel=fabric_utils.decompose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[matrices, positions_wp, orientations_wp, scales_wp, indices_wp], + device=self._device, + ) + wp.synchronize() + + def _compute_fabric_indices(self, selection: usdrt.PrimSelection) -> wp.array: + # Assign to each prim an index + fabric_paths = selection.GetPaths() + path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} + + # Look up the index for each prim observed by this view + fabric_indices: list[int] = [] + for prim_path in self.prim_paths: + fabric_idx = path_to_fabric_idx.get(prim_path) + if fabric_idx is None: + raise RuntimeError( + f"Prim '{prim_path}' not found in Fabric selection. Ensure the hierarchy has been populated." + ) + fabric_indices.append(fabric_idx) + + return wp.array(fabric_indices, dtype=wp.int32).to(self._device) + + def _ensure_fabric_indices_are_up_to_date(self, selection, force_rebuild: bool = False) -> bool: + """Build the view index to fabric index array from PrimSelection path ordering.""" + + # Rebuild indexing only when fabric topology has changed or whenever forced. + # TODO: consider what is it cheaper, store one selection with paths and call PrepareForReuse, + # Or each time call SelectPrims with the same paths and call PrepareForReuse? + + topology_changed = selection.PrepareForReuse() + + if topology_changed: + logger.warning("Fabric topology changed! Rebuilding fabric indices!") + + if not (topology_changed or force_rebuild): + return False + + self._fabric_indices = self._compute_fabric_indices(selection) + return True + + def _build_array(self, selection: usdrt.PrimSelection, attribute_name: str) -> wp.indexedfabricarray: + fa = wp.fabricarray(selection, attribute_name) + return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) + + def _select_indexed(self, attr_name: str, access) -> wp.indexedfabricarray: + """Create an indexed fabric array for a single attribute with the given access mode.""" + import usdrt + + selection = self._stage.SelectPrims( + require_attrs=[ + (usdrt.Sdf.ValueTypeNames.Matrix4d, attr_name, access), + ], + device=self._device, + want_paths=True, + ) + fa = wp.fabricarray(selection, attr_name) + return wp.indexedfabricarray(fa=fa, indices=self._fabric_indices) + + def _get_world_ro_array(self) -> wp.indexedfabricarray: + # import usdrt + # return self._select_indexed(self._WORLD_MATRIX_NAME, usdrt.Usd.Access.Read) + if self._trans_sel_ro.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) + self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_NAME) + self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_NAME) + return self._world_ifa_ro + + def _get_local_ro_array(self) -> wp.indexedfabricarray: + # import usdrt + # return self._select_indexed(self._LOCAL_MATRIX_NAME, usdrt.Usd.Access.Read) + if self._trans_sel_ro.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) + self._world_ifa_ro = self._build_array(self._trans_sel_ro, self._WORLD_MATRIX_NAME) + self._local_ifa_ro = self._build_array(self._trans_sel_ro, self._LOCAL_MATRIX_NAME) + return self._local_ifa_ro + + def _get_world_rw_array(self) -> wp.indexedfabricarray: + # import usdrt + # return self._select_indexed(self._WORLD_MATRIX_NAME, usdrt.Usd.Access.ReadWrite) + if self._world_sel_rw.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._world_sel_rw) + self._world_ifa_rw = self._build_array(self._world_sel_rw, self._WORLD_MATRIX_NAME) + return self._world_ifa_rw + + def _get_local_rw_array(self) -> wp.indexedfabricarray: + # import usdrt + # return self._select_indexed(self._LOCAL_MATRIX_NAME, usdrt.Usd.Access.ReadWrite) + if self._local_sel_rw.PrepareForReuse(): + self._fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + self._local_ifa_rw = self._build_array(self._local_sel_rw, self._LOCAL_MATRIX_NAME) + return self._local_ifa_rw + + def _convert_view_to_fabric_indices(self, indices: Sequence[int] | None) -> wp.array: + """Convert requested view indices to fabric indices. + + Args: + indices: Requested path indices. If None, then all indices are used. + + Returns: + A warp array of fabric indices. + """ + if indices is None or indices == slice(None): + if self._view_indices is None: + raise RuntimeError("Fabric indices are not initialized.") + return self._view_indices + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + return wp.array(indices_list, dtype=wp.uint32).to(self._device) + + # ------------------------------------------------------------------ + # Debug helpers + # ------------------------------------------------------------------ + + def debug_read_fabric_matrices(self, indices: Sequence[int] | None = None) -> dict[str, list]: + """Read world and local matrices directly from Fabric via USDRT for debugging. + + Bypasses the Warp kernel path entirely and reads raw ``Gf.Matrix4d`` + values per-prim through the USDRT prim API. Useful for verifying that + Fabric contains the expected data after writes or population. + + Args: + indices: Prim indices to read. Defaults to all prims. + + Returns: + Dictionary with keys ``"prim_path"``, ``"world_matrix"``, and + ``"local_matrix"``, each a list with one entry per queried prim. + """ + import usdrt + + if indices is None: + indices = list(range(self.count)) + + result: dict[str, list] = {"prim_path": [], "world_matrix": [], "local_matrix": []} + for idx in indices: + prim_path = self.prim_paths[idx] + rt_prim = self._stage.GetPrimAtPath(usdrt.Sdf.Path(prim_path)) + + world_mat = None + local_mat = None + if rt_prim.IsValid(): + if rt_prim.HasAttribute(self._WORLD_MATRIX_NAME): + world_mat = rt_prim.GetAttribute(self._WORLD_MATRIX_NAME).Get() + if rt_prim.HasAttribute(self._LOCAL_MATRIX_NAME): + local_mat = rt_prim.GetAttribute(self._LOCAL_MATRIX_NAME).Get() + + result["prim_path"].append(prim_path) + result["world_matrix"].append(world_mat) + result["local_matrix"].append(local_mat) + + return result + + def __repr__(self) -> str: + lines: list[str] = [] + indices = list(range(self.count)) + + fabric_indices_np = self._fabric_indices.numpy() + fabric_paths = self._trans_sel_ro.GetPaths() + + view_paths = self.prim_paths + + lines.append(f"[FabricBackend] stage_id={self._stage_id} device={self._device} count={self.count}") + import usdrt + + lines.append(f"SelectPrims returned {len(fabric_paths)} paths:") + for fi, fp in enumerate(fabric_paths): + marker = " *" if fp in view_paths else " " + rt_prim = self._stage.GetPrimAtPath(usdrt.Sdf.Path(str(fp))) + wm = ( + rt_prim.GetAttribute(self._WORLD_MATRIX_NAME).Get() + if rt_prim.HasAttribute(self._WORLD_MATRIX_NAME) + else None + ) + lm = ( + rt_prim.GetAttribute(self._LOCAL_MATRIX_NAME).Get() + if rt_prim.HasAttribute(self._LOCAL_MATRIX_NAME) + else None + ) + lines.append(f"{marker} fabric_idx={fi} path={fp}") + lines.append(f" world: {wm}") + lines.append(f" local: {lm}") + + lines.append("View → Fabric index mapping:") + for vi in range(self.count): + fi = int(fabric_indices_np[vi]) + lines.append(f" view_idx={vi} fabric_idx={fi} path={self.prim_paths[vi]}") + + data = self.debug_read_fabric_matrices(indices) + lines.append("Matrices:") + for i, path in enumerate(data["prim_path"]): + vi = indices[i] + fi = int(fabric_indices_np[vi]) + wm = data["world_matrix"][i] + lm = data["local_matrix"][i] + lines.append(f" [{vi}] {path} (fabric_idx={fi})") + lines.append(f" world: {wm}") + lines.append(f" local: {lm}") + + return "\n".join(lines) diff --git a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py index 211994a7226b..31296a81abb3 100644 --- a/source/isaaclab/isaaclab/sim/views/xform_prim_view.py +++ b/source/isaaclab/isaaclab/sim/views/xform_prim_view.py @@ -8,15 +8,16 @@ import logging from collections.abc import Sequence -import numpy as np import torch -import warp as wp -from pxr import Gf, Sdf, Usd, UsdGeom, Vt +from pxr import Sdf, Usd, UsdGeom import isaaclab.sim as sim_utils from isaaclab.app.settings_manager import SettingsManager -from isaaclab.utils.warp import fabric as fabric_utils + +from .xform_backend import XformBackend +from .xform_fabric_backend import FabricBackend +from .xform_usd_backend import UsdBackend logger = logging.getLogger(__name__) @@ -117,7 +118,6 @@ def __init__( ValueError: If any matched prim is not Xformable or doesn't have standardized transform operations (translate, orient, scale in that order). """ - # Store configuration self._prim_path = prim_path self._device = device @@ -125,7 +125,6 @@ def __init__( stage = sim_utils.get_current_stage() if stage is None else stage self._prims: list[Usd.Prim] = sim_utils.find_matching_prims(prim_path, stage=stage) - # Validate all prims have standard xform operations if validate_xform_ops: for prim in self._prims: sim_utils.standardize_xform_ops(prim) @@ -136,56 +135,34 @@ def __init__( " Use sim_utils.standardize_xform_ops() to prepare the prim." ) - # Determine if Fabric is supported on the device + # Determine whether Fabric is available settings = SettingsManager.instance() - self._use_fabric = bool(settings.get("/physics/fabricEnabled", False)) + use_fabric = bool(settings.get("/physics/fabricEnabled", False)) - # Check for unsupported Fabric + CPU combination - if self._use_fabric and self._device == "cpu": - logger.warning( - "Fabric mode with Warp fabric-array operations is not supported on CPU devices. " - "While Fabric itself can run on both CPU and GPU, our batch Warp kernels for " - "fabric-array operations require CUDA and are not reliable on the CPU backend. " - "To ensure stability, Fabric is being disabled and execution will fall back " - "to standard USD operations on the CPU. This may impact performance." - ) - self._use_fabric = False - - # Check for unsupported Fabric + non-primary CUDA device combination. - # USDRT SelectPrims and Warp fabric arrays only support cuda:0 internally. - # When running on cuda:1 or higher, SelectPrims raises a C++ error regardless of - # the device argument, because USDRT uses the active CUDA context (which is cuda:1). - if self._use_fabric and self._device not in ("cuda", "cuda:0"): + if use_fabric and self._device not in ("cpu", "cuda", "cuda:0"): logger.warning( f"Fabric mode is not supported on device '{self._device}'. " "USDRT SelectPrims and Warp fabric arrays only support cuda:0. " "Falling back to standard USD operations. This may impact performance." ) - self._use_fabric = False + use_fabric = False - # Create indices buffer - # Since we iterate over the indices, we need to use range instead of torch tensor + # Index list used by visibility (USD-only) self._ALL_INDICES = list(range(len(self._prims))) - # Some prims (e.g., Cameras) require USD-authored transforms for rendering. - # When enabled, mirror Fabric pose writes to USD for those prims. - self._sync_usd_on_fabric_write = sync_usd_on_fabric_write - - # Fabric batch infrastructure (initialized lazily on first use) - self._fabric_initialized = False - self._fabric_usd_sync_done = False - self._fabric_selection = None - self._fabric_to_view: wp.array | None = None - self._view_to_fabric: wp.array | None = None - self._default_view_indices: wp.array | None = None - self._fabric_hierarchy = None - # Create a valid USD attribute name: namespace:name - # Use "isaaclab" namespace to identify our custom attributes - self._view_index_attr = f"isaaclab:view_index:{abs(hash(self))}" + # ---- Create backends ------------------------------------------------ + if use_fabric: + self._backend: XformBackend = FabricBackend(self._prims, self._device) + self._sync_backends: list[XformBackend] = ( + [UsdBackend(self._prims, self._device)] if sync_usd_on_fabric_write else [] + ) + else: + self._backend = UsdBackend(self._prims, self._device) + self._sync_backends = [] - """ - Properties. - """ + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ @property def count(self) -> int: @@ -214,16 +191,13 @@ def prim_paths(self) -> list[str]: to the USD prim objects without the conversion overhead. This property is mainly useful for logging, debugging, or when string paths are explicitly required. """ - # we cache it the first time it is accessed. - # we don't compute it in constructor because it is expensive and we don't need it most of the time. - # users should usually deal with prims directly as they typically need to access the prims directly. if not hasattr(self, "_prim_paths"): self._prim_paths = [prim.GetPath().pathString for prim in self._prims] return self._prim_paths - """ - Operations - Setters. - """ + # ------------------------------------------------------------------ + # Operations – Setters + # ------------------------------------------------------------------ def set_world_poses( self, @@ -253,10 +227,9 @@ def set_world_poses( ValueError: If positions shape is not (M, 3) or orientations shape is not (M, 4). ValueError: If the number of poses doesn't match the number of indices provided. """ - if self._use_fabric: - self._set_world_poses_fabric(positions, orientations, indices) - else: - self._set_world_poses_usd(positions, orientations, indices) + self._backend.set_world_poses(positions, orientations, indices) + for sync in self._sync_backends: + sync.set_world_poses(positions, orientations, indices) def set_local_poses( self, @@ -293,10 +266,9 @@ def set_local_poses( ValueError: If translations shape is not (M, 3) or orientations shape is not (M, 4). ValueError: If the number of poses doesn't match the number of indices provided. """ - if self._use_fabric: - self._set_local_poses_fabric(translations, orientations, indices) - else: - self._set_local_poses_usd(translations, orientations, indices) + self._backend.set_local_poses(translations, orientations, indices) + for sync in self._sync_backends: + sync.set_local_poses(translations, orientations, indices) def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None): """Set scales for prims in the view. @@ -315,10 +287,9 @@ def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) Raises: ValueError: If scales shape is not (M, 3). """ - if self._use_fabric: - self._set_scales_fabric(scales, indices) - else: - self._set_scales_usd(scales, indices) + self._backend.set_scales(scales, indices) + for sync in self._sync_backends: + sync.set_scales(scales, indices) def set_visibility(self, visibility: torch.Tensor, indices: Sequence[int] | None = None): """Set visibility for prims in the view. @@ -334,30 +305,25 @@ def set_visibility(self, visibility: torch.Tensor, indices: Sequence[int] | None Raises: ValueError: If visibility shape is not (M,). """ - # Resolve indices if indices is None or indices == slice(None): indices_list = self._ALL_INDICES else: indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - # Validate inputs if visibility.shape != (len(indices_list),): raise ValueError(f"Expected visibility shape ({len(indices_list)},), got {visibility.shape}.") - # Set visibility for each prim with Sdf.ChangeBlock(): for idx, prim_idx in enumerate(indices_list): - # Convert prim to imageable imageable = UsdGeom.Imageable(self._prims[prim_idx]) - # Set visibility if visibility[idx]: imageable.MakeVisible() else: imageable.MakeInvisible() - """ - Operations - Getters. - """ + # ------------------------------------------------------------------ + # Operations – Getters + # ------------------------------------------------------------------ def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: """Get world-space poses for prims in the view. @@ -382,44 +348,34 @@ def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.T where M is the number of prims queried. - orientations: Torch tensor of shape (M, 4) containing world-space quaternions (w, x, y, z) """ - if self._use_fabric: - return self._get_world_poses_fabric(indices) - else: - return self._get_world_poses_usd(indices) + return self._backend.get_world_poses(indices) def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: """Get local-space poses for prims in the view. - This method retrieves the position and orientation of each prim in local space (relative to - their parent prims). It reads directly from USD's ``xformOp:translate`` and ``xformOp:orient`` attributes. - - Note: - Even in Fabric mode, local pose operations use USD. This behavior is based on Isaac Sim's design - where Fabric is only used for world pose operations. + This method retrieves the position and orientation of each prim in local space + (relative to their parent prims). - Rationale: - - Local pose reads need correct parent-child hierarchy relationships - - USD maintains these relationships correctly and efficiently - - Fabric is optimized for world pose operations, not local hierarchies + When Fabric is enabled, reads ``omni:fabric:localMatrix`` and decomposes it + using GPU batch operations. Otherwise reads USD's ``xformOp:translate`` and + ``xformOp:orient`` via an :class:`UsdGeom.XformCache`. Note: Scale is ignored. The returned poses contain only translation and rotation. Args: - indices: Indices of prims to get poses for. Defaults to None, in which case poses are retrieved - for all prims in the view. + indices: Indices of prims to get poses for. Defaults to None, in which + case poses are retrieved for all prims in the view. Returns: A tuple of (translations, orientations) where: - - translations: Torch tensor of shape (M, 3) containing local-space translations (x, y, z), - where M is the number of prims queried. - - orientations: Torch tensor of shape (M, 4) containing local-space quaternions (w, x, y, z) + - translations: Torch tensor of shape (M, 3) containing local-space + translations (x, y, z), where M is the number of prims queried. + - orientations: Torch tensor of shape (M, 4) containing local-space + quaternions (x, y, z, w). """ - if self._use_fabric: - return self._get_local_poses_fabric(indices) - else: - return self._get_local_poses_usd(indices) + return self._backend.get_local_poses(indices) def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: """Get scales for prims in the view. @@ -437,10 +393,7 @@ def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: Returns: A tensor of shape (M, 3) containing the scales of each prim, where M is the number of prims queried. """ - if self._use_fabric: - return self._get_scales_fabric(indices) - else: - return self._get_scales_usd(indices) + return self._backend.get_scales(indices) def get_visibility(self, indices: Sequence[int] | None = None) -> torch.Tensor: """Get visibility for prims in the view. @@ -455,686 +408,15 @@ def get_visibility(self, indices: Sequence[int] | None = None) -> torch.Tensor: A tensor of shape (M,) containing the visibility of each prim, where M is the number of prims queried. The tensor is of type bool. """ - # Resolve indices if indices is None or indices == slice(None): indices_list = self._ALL_INDICES else: - # Convert to list if it is a tensor array indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - # Create buffers visibility = torch.zeros(len(indices_list), dtype=torch.bool, device=self._device) for idx, prim_idx in enumerate(indices_list): - # Get prim imageable = UsdGeom.Imageable(self._prims[prim_idx]) - # Get visibility visibility[idx] = imageable.ComputeVisibility() != UsdGeom.Tokens.invisible return visibility - - """ - Internal Functions - USD. - """ - - def _set_world_poses_usd( - self, - positions: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set world poses to USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - # Convert to list if it is a tensor array - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Validate inputs - if positions is not None: - if positions.shape != (len(indices_list), 3): - raise ValueError( - f"Expected positions shape ({len(indices_list)}, 3), got {positions.shape}. " - "Number of positions must match the number of prims in the view." - ) - positions_array = Vt.Vec3dArray.FromNumpy(positions.cpu().numpy()) - else: - positions_array = None - if orientations is not None: - if orientations.shape != (len(indices_list), 4): - raise ValueError( - f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}. " - "Number of orientations must match the number of prims in the view." - ) - # Vt expects quaternions in xyzw order - orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) - else: - orientations_array = None - - # Create xform cache instance - xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) - - # Set poses for each prim - # We use Sdf.ChangeBlock to minimize notification overhead. - with Sdf.ChangeBlock(): - for idx, prim_idx in enumerate(indices_list): - # Get prim - prim = self._prims[prim_idx] - # Get parent prim for local space conversion - parent_prim = prim.GetParent() - - # Determine what to set - world_pos = positions_array[idx] if positions_array is not None else None - world_quat = orientations_array[idx] if orientations_array is not None else None - - # Convert world pose to local if we have a valid parent - # Note: We don't use :func:`isaaclab.sim.utils.transforms.convert_world_pose_to_local` - # here since it isn't optimized for batch operations. - if parent_prim.IsValid() and parent_prim.GetPath() != Sdf.Path.absoluteRootPath: - # Get current world pose if we're only setting one component - if positions_array is None or orientations_array is None: - # get prim xform - prim_tf = xform_cache.GetLocalToWorldTransform(prim) - # sanitize quaternion - # this is needed, otherwise the quaternion might be non-normalized - prim_tf.Orthonormalize() - # populate desired world transform - if world_pos is not None: - prim_tf.SetTranslateOnly(world_pos) - if world_quat is not None: - prim_tf.SetRotateOnly(world_quat) - else: - # Both position and orientation are provided, create new transform - prim_tf = Gf.Matrix4d() - prim_tf.SetTranslateOnly(world_pos) - prim_tf.SetRotateOnly(world_quat) - - # Convert to local space - parent_world_tf = xform_cache.GetLocalToWorldTransform(parent_prim) - local_tf = prim_tf * parent_world_tf.GetInverse() - local_pos = local_tf.ExtractTranslation() - local_quat = local_tf.ExtractRotationQuat() - else: - # No parent or parent is root, world == local - local_pos = world_pos - local_quat = world_quat - - # Get or create the standard transform operations - if local_pos is not None: - prim.GetAttribute("xformOp:translate").Set(local_pos) - if local_quat is not None: - prim.GetAttribute("xformOp:orient").Set(local_quat) - - def _set_local_poses_usd( - self, - translations: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set local poses to USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Validate inputs - if translations is not None: - if translations.shape != (len(indices_list), 3): - raise ValueError(f"Expected translations shape ({len(indices_list)}, 3), got {translations.shape}.") - translations_array = Vt.Vec3dArray.FromNumpy(translations.cpu().numpy()) - else: - translations_array = None - if orientations is not None: - if orientations.shape != (len(indices_list), 4): - raise ValueError(f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}.") - orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) - else: - orientations_array = None - - # Set local poses - with Sdf.ChangeBlock(): - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - if translations_array is not None: - prim.GetAttribute("xformOp:translate").Set(translations_array[idx]) - if orientations_array is not None: - prim.GetAttribute("xformOp:orient").Set(orientations_array[idx]) - - def _set_scales_usd(self, scales: torch.Tensor, indices: Sequence[int] | None = None): - """Set scales to USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Validate inputs - if scales.shape != (len(indices_list), 3): - raise ValueError(f"Expected scales shape ({len(indices_list)}, 3), got {scales.shape}.") - - scales_array = Vt.Vec3dArray.FromNumpy(scales.cpu().numpy()) - # Set scales for each prim - with Sdf.ChangeBlock(): - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - prim.GetAttribute("xformOp:scale").Set(scales_array[idx]) - - def _get_world_poses_usd(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get world poses from USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - # Convert to list if it is a tensor array - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Create buffers - positions = Vt.Vec3dArray(len(indices_list)) - orientations = Vt.QuatdArray(len(indices_list)) - # Create xform cache instance - xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) - - # Note: We don't use :func:`isaaclab.sim.utils.transforms.resolve_prim_pose` - # here since it isn't optimized for batch operations. - for idx, prim_idx in enumerate(indices_list): - # Get prim - prim = self._prims[prim_idx] - # get prim xform - prim_tf = xform_cache.GetLocalToWorldTransform(prim) - # sanitize quaternion - # this is needed, otherwise the quaternion might be non-normalized - prim_tf.Orthonormalize() - # extract position and orientation - positions[idx] = prim_tf.ExtractTranslation() - orientations[idx] = prim_tf.ExtractRotationQuat() - - # move to torch tensors - positions = torch.tensor(np.array(positions), dtype=torch.float32, device=self._device) - orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) - return positions, orientations # type: ignore - - def _get_local_poses_usd(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get local poses from USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Create buffers - translations = Vt.Vec3dArray(len(indices_list)) - orientations = Vt.QuatdArray(len(indices_list)) - - # Create a fresh XformCache to avoid stale cached values - xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) - - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - prim_tf = xform_cache.GetLocalTransformation(prim)[0] - prim_tf.Orthonormalize() - translations[idx] = prim_tf.ExtractTranslation() - orientations[idx] = prim_tf.ExtractRotationQuat() - - translations = torch.tensor(np.array(translations), dtype=torch.float32, device=self._device) - orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) - return translations, orientations # type: ignore - - def _get_scales_usd(self, indices: Sequence[int] | None = None) -> torch.Tensor: - """Get scales from USD.""" - # Resolve indices - if indices is None or indices == slice(None): - indices_list = self._ALL_INDICES - else: - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - - # Create buffers - scales = Vt.Vec3dArray(len(indices_list)) - - for idx, prim_idx in enumerate(indices_list): - prim = self._prims[prim_idx] - scales[idx] = prim.GetAttribute("xformOp:scale").Get() - - # Convert to tensor - return torch.tensor(np.array(scales), dtype=torch.float32, device=self._device) - - """ - Internal Functions - Fabric. - """ - - def _set_world_poses_fabric( - self, - positions: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set world poses using Fabric GPU batch operations. - - Writes directly to Fabric's ``omni:fabric:worldMatrix`` attribute using Warp kernels. - Changes are propagated through Fabric's hierarchy system but remain GPU-resident. - - For workflows mixing Fabric world pose writes with USD local pose queries, note - that local poses read from USD's xformOp:* attributes, which may not immediately - reflect Fabric changes. For best performance and consistency, use Fabric methods - exclusively (get_world_poses/set_world_poses with Fabric enabled). - """ - # Lazy initialization - if not self._fabric_initialized: - self._initialize_fabric() - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Convert torch to warp (if provided), use dummy arrays for None to avoid Warp kernel issues - if positions is not None: - positions_wp = wp.from_torch(positions) - else: - positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - if orientations is not None: - orientations_wp = wp.from_torch(orientations) - else: - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - - # Dummy array for scales (not modifying) - scales_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - # Use cached fabricarray for world matrices - world_matrices = self._fabric_world_matrices - - # Batch compose matrices with a single kernel launch - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, - orientations_wp, - scales_wp, # dummy array instead of None - False, # broadcast_positions - False, # broadcast_orientations - False, # broadcast_scales - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Synchronize to ensure kernel completes - wp.synchronize() - - # Update world transforms within Fabric hierarchy - self._fabric_hierarchy.update_world_xforms() - # Fabric now has authoritative data; skip future USD syncs - self._fabric_usd_sync_done = True - # Mirror to USD for renderer-facing prims when enabled. - if self._sync_usd_on_fabric_write: - self._set_world_poses_usd(positions, orientations, indices) - - # Fabric writes are GPU-resident; local pose operations still use USD. - - def _set_local_poses_fabric( - self, - translations: torch.Tensor | None = None, - orientations: torch.Tensor | None = None, - indices: Sequence[int] | None = None, - ): - """Set local poses using USD (matches Isaac Sim's design). - - Note: Even in Fabric mode, local pose operations use USD. - This is Isaac Sim's design: the ``usd=False`` parameter only affects world poses. - - Rationale: - - Local pose writes need correct parent-child hierarchy relationships - - USD maintains these relationships correctly and efficiently - - Fabric is optimized for world pose operations, not local hierarchies - """ - self._set_local_poses_usd(translations, orientations, indices) - - def _set_scales_fabric(self, scales: torch.Tensor, indices: Sequence[int] | None = None): - """Set scales using Fabric GPU batch operations.""" - # Lazy initialization - if not self._fabric_initialized: - self._initialize_fabric() - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Convert torch to warp - scales_wp = wp.from_torch(scales) - - # Dummy arrays for positions and orientations (not modifying) - positions_wp = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - orientations_wp = wp.zeros((0, 4), dtype=wp.float32).to(self._device) - - # Use cached fabricarray for world matrices - world_matrices = self._fabric_world_matrices - - # Batch compose matrices on GPU with a single kernel launch - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, # dummy array instead of None - orientations_wp, # dummy array instead of None - scales_wp, - False, # broadcast_positions - False, # broadcast_orientations - False, # broadcast_scales - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Synchronize to ensure kernel completes before syncing - wp.synchronize() - - # Update world transforms to propagate changes - self._fabric_hierarchy.update_world_xforms() - # Fabric now has authoritative data; skip future USD syncs - self._fabric_usd_sync_done = True - # Mirror to USD for renderer-facing prims when enabled. - if self._sync_usd_on_fabric_write: - self._set_scales_usd(scales, indices) - - def _get_world_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get world poses from Fabric using GPU batch operations.""" - # Lazy initialization of Fabric infrastructure - if not self._fabric_initialized: - self._initialize_fabric() - # Sync once from USD to ensure reads see the latest authored transforms - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Use pre-allocated buffers for full reads, allocate only for partial reads - use_cached_buffers = indices is None or indices == slice(None) - if use_cached_buffers: - # Full read: Use cached buffers (zero allocation overhead!) - positions_wp = self._fabric_positions_buffer - orientations_wp = self._fabric_orientations_buffer - scales_wp = self._fabric_dummy_buffer - else: - # Partial read: Need to allocate buffers of appropriate size - positions_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) - orientations_wp = wp.zeros((count, 4), dtype=wp.float32).to(self._device) - scales_wp = self._fabric_dummy_buffer # Always use dummy for scales - - # Use cached fabricarray for world matrices - # This eliminates the 0.06-0.30ms variability from creating fabricarray each call - world_matrices = self._fabric_world_matrices - - # Launch GPU kernel to decompose matrices in parallel - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, - orientations_wp, - scales_wp, # dummy array instead of None - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Return tensors: zero-copy for cached buffers, conversion for partial reads - if use_cached_buffers: - # Zero-copy! The Warp kernel wrote directly into the PyTorch tensors - # We just need to synchronize to ensure the kernel is done - wp.synchronize() - return self._fabric_positions_torch, self._fabric_orientations_torch - else: - # Partial read: Need to convert from Warp to torch - positions = wp.to_torch(positions_wp) - orientations = wp.to_torch(orientations_wp) - return positions, orientations - - def _get_local_poses_fabric(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: - """Get local poses using USD (matches Isaac Sim's design). - - Note: - Even in Fabric mode, local pose operations use USD's XformCache. - This is Isaac Sim's design: the ``usd=False`` parameter only affects world poses. - - Rationale: - - Local pose computation requires parent transforms which may not be in the view - - USD's XformCache provides efficient hierarchy-aware local transform queries - - Fabric is optimized for world pose operations, not local hierarchies - """ - return self._get_local_poses_usd(indices) - - def _get_scales_fabric(self, indices: Sequence[int] | None = None) -> torch.Tensor: - """Get scales from Fabric using GPU batch operations.""" - # Lazy initialization - if not self._fabric_initialized: - self._initialize_fabric() - # Sync once from USD to ensure reads see the latest authored transforms - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() - - # Resolve indices (treat slice(None) as None for consistency with USD path) - indices_wp = self._resolve_indices_wp(indices) - - count = indices_wp.shape[0] - - # Use pre-allocated buffers for full reads, allocate only for partial reads - use_cached_buffers = indices is None or indices == slice(None) - if use_cached_buffers: - # Full read: Use cached buffers (zero allocation overhead!) - scales_wp = self._fabric_scales_buffer - else: - # Partial read: Need to allocate buffer of appropriate size - scales_wp = wp.zeros((count, 3), dtype=wp.float32).to(self._device) - - # Always use dummy buffers for positions and orientations (not needed for scales) - positions_wp = self._fabric_dummy_buffer - orientations_wp = self._fabric_dummy_buffer - - # Use cached fabricarray for world matrices - world_matrices = self._fabric_world_matrices - - # Launch GPU kernel to decompose matrices in parallel - # Note: world_matrices is a fabricarray on fabric_device, so we must launch on fabric_device - wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, - dim=count, - inputs=[ - world_matrices, - positions_wp, # dummy array instead of None - orientations_wp, # dummy array instead of None - scales_wp, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - - # Return tensor: zero-copy for cached buffers, conversion for partial reads - if use_cached_buffers: - # Zero-copy! The Warp kernel wrote directly into the PyTorch tensor - wp.synchronize() - return self._fabric_scales_torch - else: - # Partial read: Need to convert from Warp to torch - return wp.to_torch(scales_wp) - - """ - Internal Functions - Initialization. - """ - - def _initialize_fabric(self) -> None: - """Initialize Fabric batch infrastructure for GPU-accelerated pose queries. - - This method ensures all prims have the required Fabric hierarchy attributes - (``omni:fabric:localMatrix`` and ``omni:fabric:worldMatrix``) and creates the necessary - infrastructure for batch GPU operations using Warp. - - Based on the Fabric Hierarchy documentation, when Fabric Scene Delegate is enabled, - all boundable prims should have these attributes. This method ensures they exist - and are properly synchronized with USD. - """ - import usdrt - from usdrt import Rt - - # Get USDRT (Fabric) stage - stage_id = sim_utils.get_current_stage_id() - fabric_stage = usdrt.Usd.Stage.Attach(stage_id) - - # Step 1: Ensure all prims have Fabric hierarchy attributes - # According to the documentation, these attributes are created automatically - # when Fabric Scene Delegate is enabled, but we ensure they exist - for i in range(self.count): - rt_prim = fabric_stage.GetPrimAtPath(self.prim_paths[i]) - rt_xformable = Rt.Xformable(rt_prim) - - # Create Fabric hierarchy world matrix attribute if it doesn't exist - has_attr = ( - rt_xformable.HasFabricHierarchyWorldMatrixAttr() - if hasattr(rt_xformable, "HasFabricHierarchyWorldMatrixAttr") - else False - ) - if not has_attr: - rt_xformable.CreateFabricHierarchyWorldMatrixAttr() - - # Best-effort USD->Fabric sync; authoritative initialization happens on first read. - rt_xformable.SetWorldXformFromUsd() - - # Create view index attribute for batch operations - rt_prim.CreateAttribute(self._view_index_attr, usdrt.Sdf.ValueTypeNames.UInt, custom=True) - rt_prim.GetAttribute(self._view_index_attr).Set(i) - - # After syncing all prims, update the Fabric hierarchy to ensure world matrices are computed - self._fabric_hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( - fabric_stage.GetFabricId(), fabric_stage.GetStageIdAsStageId() - ) - self._fabric_hierarchy.update_world_xforms() - - # Step 2: Create index arrays for batch operations - self._default_view_indices = wp.zeros((self.count,), dtype=wp.uint32).to(self._device) - wp.launch( - kernel=fabric_utils.arange_k, - dim=self.count, - inputs=[self._default_view_indices], - device=self._device, - ) - wp.synchronize() # Ensure indices are ready - - # Step 3: Create Fabric selection with attribute filtering - # SelectPrims expects device format like "cuda:0" not "cuda" - # - # KNOWN ISSUE: SelectPrims may return prims in a different order than self._prims - # (which comes from USD's find_matching_prims). We create a bidirectional mapping - # (_view_to_fabric and _fabric_to_view) to handle this ordering difference. - # This works correctly for full-view operations but partial indexing still has issues. - # - # NOTE: SelectPrims only supports "cuda:0" regardless of which GPU the simulation - # is running on. In multi-GPU setups, we must use "cuda:0" for SelectPrims even if - # the simulation device is "cuda:1" or higher. - fabric_device = self._device - if self._device == "cuda": - logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.") - fabric_device = "cuda:0" - elif self._device.startswith("cuda:"): - # SelectPrims only supports cuda:0, so we always use cuda:0 for SelectPrims - # even if the simulation is running on a different GPU - if self._device != "cuda:0": - logger.debug( - f"SelectPrims only supports cuda:0. Using cuda:0 for SelectPrims " - f"even though simulation device is {self._device}." - ) - fabric_device = "cuda:0" - - self._fabric_selection = fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, "omni:fabric:worldMatrix", usdrt.Usd.Access.ReadWrite), - ], - device=fabric_device, - ) - - # Step 4: Create bidirectional mapping between view and fabric indices - # Note: fabric_to_view is tied to fabric_device (cuda:0) because it's created from SelectPrims. - # view_to_fabric must also be on fabric_device since it's always used with fabricarrays in kernels. - self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32).to(fabric_device) - self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr) - - wp.launch( - kernel=fabric_utils.set_view_to_fabric_array, - dim=self._fabric_to_view.shape[0], - inputs=[self._fabric_to_view, self._view_to_fabric], - device=fabric_device, - ) - # Synchronize to ensure mapping is ready before any operations - wp.synchronize() - - # Pre-allocate reusable output buffers for read operations - self._fabric_positions_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) - self._fabric_orientations_torch = torch.zeros((self.count, 4), dtype=torch.float32, device=self._device) - self._fabric_scales_torch = torch.zeros((self.count, 3), dtype=torch.float32, device=self._device) - - # Create Warp views of the PyTorch tensors - self._fabric_positions_buffer = wp.from_torch(self._fabric_positions_torch, dtype=wp.float32) - self._fabric_orientations_buffer = wp.from_torch(self._fabric_orientations_torch, dtype=wp.float32) - self._fabric_scales_buffer = wp.from_torch(self._fabric_scales_torch, dtype=wp.float32) - - # Dummy array for unused outputs (always empty) - self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32).to(self._device) - - # Cache fabricarray for world matrices to avoid recreation overhead - # Refs: https://docs.omniverse.nvidia.com/kit/docs/usdrt/latest/docs/usdrt_prim_selection.html - # https://docs.omniverse.nvidia.com/kit/docs/usdrt/latest/docs/scenegraph_use.html - self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") - - # Cache Fabric stage to avoid expensive get_current_stage() calls - self._fabric_stage = fabric_stage - - # Store fabric_device for use in kernel launches that involve fabricarrays - self._fabric_device = fabric_device - - self._fabric_initialized = True - # Force a one-time USD->Fabric sync on first read to pick up any USD edits - # made after the view was constructed. - self._fabric_usd_sync_done = False - - def _sync_fabric_from_usd_once(self) -> None: - """Sync Fabric world matrices from USD once, on the first read.""" - # Ensure Fabric is initialized - if not self._fabric_initialized: - self._initialize_fabric() - - # Read authoritative transforms from USD and write once into Fabric. - positions_usd, orientations_usd = self._get_world_poses_usd() - scales_usd = self._get_scales_usd() - - prev_sync = self._sync_usd_on_fabric_write - self._sync_usd_on_fabric_write = False - self._set_world_poses_fabric(positions_usd, orientations_usd) - self._set_scales_fabric(scales_usd) - self._sync_usd_on_fabric_write = prev_sync - - self._fabric_usd_sync_done = True - - def _resolve_indices_wp(self, indices: Sequence[int] | None) -> wp.array: - """Resolve view indices as a Warp array.""" - if indices is None or indices == slice(None): - if self._default_view_indices is None: - raise RuntimeError("Fabric indices are not initialized.") - return self._default_view_indices - indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) - return wp.array(indices_list, dtype=wp.uint32).to(self._device) diff --git a/source/isaaclab/isaaclab/sim/views/xform_usd_backend.py b/source/isaaclab/isaaclab/sim/views/xform_usd_backend.py new file mode 100644 index 000000000000..f03b59e2a21e --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_usd_backend.py @@ -0,0 +1,220 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import Sequence + +import numpy as np +import torch + +from pxr import Gf, Sdf, Usd, UsdGeom, Vt + + +class UsdBackend: + """USD-based transform backend for :class:`XformPrimView`. + + Reads and writes transforms through USD's ``xformOp`` attributes, using + :class:`UsdGeom.XformCache` for efficient world-space computations. + """ + + def __init__(self, prims: list[Usd.Prim], device: str): + self._prims = prims + self._device = device + self._ALL_INDICES = list(range(len(prims))) + + @property + def count(self) -> int: + """Number of prims managed by this backend.""" + return len(self._prims) + + def initialize(self) -> None: + """No-op for the USD backend (no deferred setup needed).""" + + # ------------------------------------------------------------------ + # Setters + # ------------------------------------------------------------------ + + def set_world_poses( + self, + positions: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set world-space poses by converting to local space and writing to USD.""" + # Resolve indices + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + # Validate inputs + if positions is not None: + if positions.shape != (len(indices_list), 3): + raise ValueError( + f"Expected positions shape ({len(indices_list)}, 3), got {positions.shape}. " + "Number of positions must match the number of prims in the view." + ) + positions_array = Vt.Vec3dArray.FromNumpy(positions.cpu().numpy()) + else: + positions_array = None + if orientations is not None: + if orientations.shape != (len(indices_list), 4): + raise ValueError( + f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}. " + "Number of orientations must match the number of prims in the view." + ) + orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) + else: + orientations_array = None + + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + parent_prim = prim.GetParent() + + world_pos = positions_array[idx] if positions_array is not None else None + world_quat = orientations_array[idx] if orientations_array is not None else None + + if parent_prim.IsValid() and parent_prim.GetPath() != Sdf.Path.absoluteRootPath: + if positions_array is None or orientations_array is None: + prim_tf = xform_cache.GetLocalToWorldTransform(prim) + prim_tf.Orthonormalize() + if world_pos is not None: + prim_tf.SetTranslateOnly(world_pos) + if world_quat is not None: + prim_tf.SetRotateOnly(world_quat) + else: + prim_tf = Gf.Matrix4d() + prim_tf.SetTranslateOnly(world_pos) + prim_tf.SetRotateOnly(world_quat) + + parent_world_tf = xform_cache.GetLocalToWorldTransform(parent_prim) + local_tf = prim_tf * parent_world_tf.GetInverse() + local_pos = local_tf.ExtractTranslation() + local_quat = local_tf.ExtractRotationQuat() + else: + local_pos = world_pos + local_quat = world_quat + + if local_pos is not None: + prim.GetAttribute("xformOp:translate").Set(local_pos) + if local_quat is not None: + prim.GetAttribute("xformOp:orient").Set(local_quat) + + def set_local_poses( + self, + translations: torch.Tensor | None = None, + orientations: torch.Tensor | None = None, + indices: Sequence[int] | None = None, + ) -> None: + """Set local-space poses directly on USD xformOp attributes.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + if translations is not None: + if translations.shape != (len(indices_list), 3): + raise ValueError(f"Expected translations shape ({len(indices_list)}, 3), got {translations.shape}.") + translations_array = Vt.Vec3dArray.FromNumpy(translations.cpu().numpy()) + else: + translations_array = None + if orientations is not None: + if orientations.shape != (len(indices_list), 4): + raise ValueError(f"Expected orientations shape ({len(indices_list)}, 4), got {orientations.shape}.") + orientations_array = Vt.QuatdArray.FromNumpy(orientations.cpu().numpy()) + else: + orientations_array = None + + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + if translations_array is not None: + prim.GetAttribute("xformOp:translate").Set(translations_array[idx]) + if orientations_array is not None: + prim.GetAttribute("xformOp:orient").Set(orientations_array[idx]) + + def set_scales(self, scales: torch.Tensor, indices: Sequence[int] | None = None) -> None: + """Set scales on USD ``xformOp:scale`` attributes.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + if scales.shape != (len(indices_list), 3): + raise ValueError(f"Expected scales shape ({len(indices_list)}, 3), got {scales.shape}.") + + scales_array = Vt.Vec3dArray.FromNumpy(scales.cpu().numpy()) + + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + prim.GetAttribute("xformOp:scale").Set(scales_array[idx]) + + # ------------------------------------------------------------------ + # Getters + # ------------------------------------------------------------------ + + def get_world_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Get world-space poses via :class:`UsdGeom.XformCache`.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + positions = Vt.Vec3dArray(len(indices_list)) + orientations = Vt.QuatdArray(len(indices_list)) + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + prim_tf = xform_cache.GetLocalToWorldTransform(prim) + prim_tf.Orthonormalize() + positions[idx] = prim_tf.ExtractTranslation() + orientations[idx] = prim_tf.ExtractRotationQuat() + + positions = torch.tensor(np.array(positions), dtype=torch.float32, device=self._device) + orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) + return positions, orientations # type: ignore + + def get_local_poses(self, indices: Sequence[int] | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """Get local-space poses via :class:`UsdGeom.XformCache`.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + translations = Vt.Vec3dArray(len(indices_list)) + orientations = Vt.QuatdArray(len(indices_list)) + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + prim_tf = xform_cache.GetLocalTransformation(prim)[0] + prim_tf.Orthonormalize() + translations[idx] = prim_tf.ExtractTranslation() + orientations[idx] = prim_tf.ExtractRotationQuat() + + translations = torch.tensor(np.array(translations), dtype=torch.float32, device=self._device) + orientations = torch.tensor(np.array(orientations), dtype=torch.float32, device=self._device) + return translations, orientations # type: ignore + + def get_scales(self, indices: Sequence[int] | None = None) -> torch.Tensor: + """Get scales from USD ``xformOp:scale`` attributes.""" + if indices is None or indices == slice(None): + indices_list = self._ALL_INDICES + else: + indices_list = indices.tolist() if isinstance(indices, torch.Tensor) else list(indices) + + scales = Vt.Vec3dArray(len(indices_list)) + + for idx, prim_idx in enumerate(indices_list): + prim = self._prims[prim_idx] + scales[idx] = prim.GetAttribute("xformOp:scale").Get() + + return torch.tensor(np.array(scales), dtype=torch.float32, device=self._device) diff --git a/source/isaaclab/isaaclab/utils/warp/fabric.py b/source/isaaclab/isaaclab/utils/warp/fabric.py index a48f773f4991..e38ea0e7b337 100644 --- a/source/isaaclab/isaaclab/utils/warp/fabric.py +++ b/source/isaaclab/isaaclab/utils/warp/fabric.py @@ -18,12 +18,14 @@ if TYPE_CHECKING: FabricArrayUInt32 = Any FabricArrayMat44d = Any + IndexedFabricArrayMat44d = Any ArrayUInt32 = Any ArrayUInt32_1d = Any ArrayFloat32_2d = Any else: FabricArrayUInt32 = wp.fabricarray(dtype=wp.uint32) FabricArrayMat44d = wp.fabricarray(dtype=wp.mat44d) + IndexedFabricArrayMat44d = wp.indexedfabricarray(dtype=wp.mat44d) ArrayUInt32 = wp.array(ndim=1, dtype=wp.uint32) ArrayUInt32_1d = wp.array(dtype=wp.uint32) ArrayFloat32_2d = wp.array(ndim=2, dtype=wp.float32) @@ -37,13 +39,6 @@ def set_view_to_fabric_array(fabric_to_view: FabricArrayUInt32, view_to_fabric: view_to_fabric[view_idx] = wp.uint32(fabric_idx) -@wp.kernel(enable_backward=False) -def arange_k(a: ArrayUInt32_1d): - """Fill array with sequential indices.""" - tid = int(wp.tid()) - a[tid] = wp.uint32(tid) - - @wp.kernel(enable_backward=False) def decompose_fabric_transformation_matrix_to_warp_arrays( fabric_matrices: FabricArrayMat44d, @@ -163,6 +158,109 @@ def compose_fabric_transformation_matrix_from_warp_arrays( ) +@wp.kernel(enable_backward=False) +def decompose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + indices: ArrayUInt32, +): + """Decompose indexed Fabric transformation matrices into position, orientation, and scale. + + Like :func:`decompose_fabric_transformation_matrix_to_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices. + array_positions: Output array for positions [m], shape (N, 3). + array_orientations: Output array for quaternions in xyzw format, shape (N, 4). + array_scales: Output array for scales, shape (N, 3). + indices: View indices to process (subset selection). + """ + output_index = wp.tid() + view_index = indices[output_index] + + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + array_positions[output_index, 0] = position[0] + array_positions[output_index, 1] = position[1] + array_positions[output_index, 2] = position[2] + if array_orientations.shape[0] > 0: + array_orientations[output_index, 0] = rotation[0] + array_orientations[output_index, 1] = rotation[1] + array_orientations[output_index, 2] = rotation[2] + array_orientations[output_index, 3] = rotation[3] + if array_scales.shape[0] > 0: + array_scales[output_index, 0] = scale[0] + array_scales[output_index, 1] = scale[1] + array_scales[output_index, 2] = scale[2] + + +@wp.kernel(enable_backward=False) +def compose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + broadcast_positions: bool, + broadcast_orientations: bool, + broadcast_scales: bool, + indices: ArrayUInt32, +): + """Compose indexed Fabric transformation matrices from position, orientation, and scale. + + Like :func:`compose_fabric_transformation_matrix_from_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices to update. + array_positions: Input array for positions [m], shape (N, 3). + array_orientations: Input array for quaternions in xyzw format, shape (N, 4). + array_scales: Input array for scales, shape (N, 3). + broadcast_positions: If True, use first position for all prims. + broadcast_orientations: If True, use first orientation for all prims. + broadcast_scales: If True, use first scale for all prims. + indices: View indices to process (subset selection). + """ + i = wp.tid() + view_index = indices[i] + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + if broadcast_positions: + index = 0 + else: + index = i + position[0] = array_positions[index, 0] + position[1] = array_positions[index, 1] + position[2] = array_positions[index, 2] + if array_orientations.shape[0] > 0: + if broadcast_orientations: + index = 0 + else: + index = i + rotation[0] = array_orientations[index, 0] + rotation[1] = array_orientations[index, 1] + rotation[2] = array_orientations[index, 2] + rotation[3] = array_orientations[index, 3] + if array_scales.shape[0] > 0: + if broadcast_scales: + index = 0 + else: + index = i + scale[0] = array_scales[index, 0] + scale[1] = array_scales[index, 1] + scale[2] = array_scales[index, 2] + + fabric_matrices[view_index] = wp.mat44d( # type: ignore[arg-type] + wp.transpose(wp.transform_compose(position, rotation, scale)) # type: ignore[arg-type] + ) + + @wp.func def _decompose_transformation_matrix(m: Any): # -> tuple[wp.vec3f, wp.quatf, wp.vec3f] """Decompose a 4x4 transformation matrix into position, orientation, and scale. diff --git a/source/isaaclab/test/sensors/test_tiled_camera.py b/source/isaaclab/test/sensors/test_tiled_camera.py index 0fcf5f26c612..4bece04b9800 100644 --- a/source/isaaclab/test/sensors/test_tiled_camera.py +++ b/source/isaaclab/test/sensors/test_tiled_camera.py @@ -1615,6 +1615,112 @@ def test_output_equal_to_usd_camera_intrinsics(setup_camera, device): del camera_usd +@pytest.mark.parametrize( + "camera_cls,cfg_cls", + [(TiledCamera, TiledCameraCfg), (Camera, CameraCfg)], + ids=["tiled", "non_tiled"], +) +@pytest.mark.parametrize( + "device", + [ + "cpu", + "cuda:0", + ], +) +@pytest.mark.isaacsim_ci +def test_camera_pose_update_reflected_in_render(setup_camera, device, camera_cls, cfg_cls): + """Test that moving a camera is reflected in rendered depth. + + Both camera types must produce different depth images when the + camera is repositioned from close to far. + """ + sim, __unused_camera_cfg, dt = setup_camera + cam_cfg = cfg_cls( + prim_path="/World/PoseTestCam", + height=128, + width=256, + update_period=0, + update_latest_camera_pose=True, + data_types=["distance_to_camera"], + spawn=sim_utils.PinholeCameraCfg( + focal_length=24.0, + focus_distance=400.0, + horizontal_aperture=20.955, + clipping_range=(0.1, 1.0e5), + ), + ) + camera = camera_cls(cam_cfg) + + sim.reset() + + target = torch.tensor( + [[0.0, 0.0, 0.0]], + dtype=torch.float32, + device=camera.device, + ) + + # Position A: close to scene objects + eyes_close = torch.tensor( + [[2.0, 2.0, 2.0]], + dtype=torch.float32, + device=camera.device, + ) + camera.set_world_poses_from_view(eyes_close, target) + sim.step() + camera.update(dt) + depth_close = camera.data.output["distance_to_camera"].clone() + + # Position B: far from scene objects + eyes_far = torch.tensor( + [[8.0, 8.0, 8.0]], + dtype=torch.float32, + device=camera.device, + ) + camera.set_world_poses_from_view(eyes_far, target) + sim.step() + camera.update(dt) + depth_far = camera.data.output["distance_to_camera"].clone() + + max_range = cam_cfg.spawn.clipping_range[1] + + def _save_depth_image(depth: torch.Tensor, filename: str): + """Save a depth tensor as a min-max normalized grayscale PNG for visual comparison.""" + img = depth[0].squeeze().cpu().float() + valid = img[img < max_range] + if valid.numel() > 0: + lo, hi = valid.min(), valid.max() + img = img.clamp(lo, hi) + img = (img - lo) / (hi - lo + 1e-6) + else: + img = img.clamp(0, max_range) / max_range + img = (img * 255).to(torch.uint8).numpy() + from PIL import Image + + Image.fromarray(img).save(filename) + print(f"[DEBUG] Saved depth image: {filename}") + + cam_type = "tiled" if camera_cls is TiledCamera else "non_tiled" + _save_depth_image(depth_close, f"/tmp/depth_{cam_type}_{device}_close.png") + _save_depth_image(depth_far, f"/tmp/depth_{cam_type}_{device}_far.png") + + valid_close = depth_close[depth_close < max_range] + valid_far = depth_far[depth_far < max_range] + + assert valid_close.numel() > 0, "No valid depth pixels from close position" + assert valid_far.numel() > 0, "No valid depth pixels from far position" + + mean_close = valid_close.mean().item() + mean_far = valid_far.mean().item() + + assert mean_far > mean_close * 1.5, ( + f"Mean depth from far ({mean_far:.2f}) should be" + f" >= 1.5x close ({mean_close:.2f}). Renderer may" + " not be observing the updated camera pose." + ) + + del camera + + @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) @pytest.mark.isaacsim_ci def test_sensor_print(setup_camera, device): diff --git a/source/isaaclab/test/sim/test_views_xform_prim.py b/source/isaaclab/test/sim/test_views_xform_prim.py index 3de7a0b357a2..58e9271cd8fd 100644 --- a/source/isaaclab/test/sim/test_views_xform_prim.py +++ b/source/isaaclab/test/sim/test_views_xform_prim.py @@ -7,8 +7,9 @@ from isaaclab.app import AppLauncher -# launch omniverse app -simulation_app = AppLauncher(headless=True).app +# In order to test Fabric backend we need to enable cameras. This setting will enable +# Fabric Scene Delegate, allowing us to test Fabric operations, such as hierarchy updates. +simulation_app = AppLauncher(headless=True, enable_cameras=True).app """Rest everything follows.""" @@ -30,7 +31,7 @@ def test_setup_teardown(): """Create a blank new stage for each test.""" # Setup: Create a new stage - sim_utils.create_new_stage() + sim_utils.create_new_stage(create_fabric_stage=True) sim_utils.update_stage() # Yield for the test @@ -62,8 +63,6 @@ def _skip_if_backend_unavailable(backend: str, device: str): """Skip tests when the requested backend is unavailable.""" if device.startswith("cuda") and not torch.cuda.is_available(): pytest.skip("CUDA not available") - if backend == "fabric" and device == "cpu": - pytest.skip("Warp fabricarray operations on CPU have known issues") def _prim_type_for_backend(backend: str) -> str: @@ -238,7 +237,7 @@ def test_xform_prim_view_initialization_empty_pattern(device): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - sim_utils.create_new_stage() + sim_utils.create_new_stage(create_fabric_stage=True) # Create view with pattern that matches nothing view = XformPrimView("/World/NonExistent_.*", device=device) @@ -1513,3 +1512,68 @@ def test_fabric_usd_consistency(device): pos_fabric_after, quat_fabric_after = view_fabric.get_world_poses() torch.testing.assert_close(pos_fabric_after, new_positions, atol=1e-4, rtol=0) torch.testing.assert_close(quat_fabric_after, new_orientations, atol=1e-4, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_fabric_topology_change_write_lands_on_correct_prim(device): + """Test that writes land on the correct prim after Fabric topology changes. + + Creates 3 Camera prims, builds a view for Cam_0 and Cam_2 only, then + adds a custom attribute to Cam_1 (forcing a bucket reorganization in + Fabric). After that, writes new positions through the view and verifies + via raw usdrt reads that the correct prims received the data. + """ + import usdrt + + _skip_if_backend_unavailable("fabric", device) + + stage = sim_utils.get_current_stage() + + # Step 1: create 3 Camera prims with known positions + sim_utils.create_prim("/World/Cam_0", "Camera", translation=(1.0, 0.0, 0.0), stage=stage) + sim_utils.create_prim("/World/Cam_1", "Camera", translation=(2.0, 0.0, 0.0), stage=stage) + sim_utils.create_prim("/World/Cam_2", "Camera", translation=(3.0, 0.0, 0.0), stage=stage) + + # Step 2: create a Fabric view for Cam_0 and Cam_2 only + sim_ctx = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view = XformPrimView("/World/Cam_[02]", device=device) + assert view.count == 2 + assert set(view.prim_paths) == {"/World/Cam_0", "/World/Cam_2"} + + # Step 3: add a custom attribute to Cam_1 to force a bucket change + fab_stage = usdrt.Usd.Stage.Attach(sim_utils.get_current_stage_id()) + cam1 = fab_stage.GetPrimAtPath(usdrt.Sdf.Path("/World/Cam_1")) + cam1.CreateAttribute("custom:dummyFloat", usdrt.Sdf.ValueTypeNames.Float, True).Set(42.0) + + # Step 4: write new positions through the view + new_positions = torch.tensor( + [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], + dtype=torch.float32, + device=device, + ) + view.set_world_poses(positions=new_positions) + + # Step 5: verify via raw usdrt that the correct prims got the new data + def _read_fabric_position(prim_path: str) -> tuple[float, float, float]: + prim = fab_stage.GetPrimAtPath(usdrt.Sdf.Path(prim_path)) + mat = prim.GetAttribute("omni:fabric:worldMatrix").Get() + return (mat[3][0], mat[3][1], mat[3][2]) + + pos_cam0 = _read_fabric_position("/World/Cam_0") + pos_cam1 = _read_fabric_position("/World/Cam_1") + pos_cam2 = _read_fabric_position("/World/Cam_2") + + # Cam_0 should have (10, 20, 30) + assert abs(pos_cam0[0] - 10.0) < 0.01, f"Cam_0 x: expected 10.0, got {pos_cam0[0]}" + assert abs(pos_cam0[1] - 20.0) < 0.01, f"Cam_0 y: expected 20.0, got {pos_cam0[1]}" + assert abs(pos_cam0[2] - 30.0) < 0.01, f"Cam_0 z: expected 30.0, got {pos_cam0[2]}" + + # Cam_2 should have (40, 50, 60) + assert abs(pos_cam2[0] - 40.0) < 0.01, f"Cam_2 x: expected 40.0, got {pos_cam2[0]}" + assert abs(pos_cam2[1] - 50.0) < 0.01, f"Cam_2 y: expected 50.0, got {pos_cam2[1]}" + assert abs(pos_cam2[2] - 60.0) < 0.01, f"Cam_2 z: expected 60.0, got {pos_cam2[2]}" + + # Cam_1 should NOT have been touched (still at its original position) + assert abs(pos_cam1[0] - 2.0) < 0.01 or abs(pos_cam1[0]) < 0.01, ( + f"Cam_1 x: expected ~2.0 or ~0.0, got {pos_cam1[0]} — write leaked to wrong prim!" + )