diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index a378a2333e0f..4812e93c7f97 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -131,6 +131,7 @@ The following modules are available in the ``isaaclab_physx`` extension: sensors sim.schemas sim.spawners + sim.views .. toctree:: :hidden: @@ -142,6 +143,7 @@ The following modules are available in the ``isaaclab_physx`` extension: lab_physx/isaaclab_physx.sensors lab_physx/isaaclab_physx.sim.schemas lab_physx/isaaclab_physx.sim.spawners + lab_physx/isaaclab_physx.sim.views isaaclab_newton extension ------------------------- diff --git a/docs/source/api/lab/isaaclab.utils.rst b/docs/source/api/lab/isaaclab.utils.rst index 5b352152e0b5..bf9291c25b3a 100644 --- a/docs/source/api/lab/isaaclab.utils.rst +++ b/docs/source/api/lab/isaaclab.utils.rst @@ -188,3 +188,16 @@ Warp operations :members: :imported-members: :show-inheritance: + +Warp Fabric kernels +^^^^^^^^^^^^^^^^^^^ + +Warp kernels for reading and writing Fabric ``Matrix4d`` attributes +(``omni:fabric:worldMatrix`` / ``omni:fabric:localMatrix``) via +:class:`wp.fabricarray` and :class:`wp.indexedfabricarray`. Used by +:class:`~isaaclab_physx.sim.views.FabricFrameView` to keep child world and +local matrices consistent without round-tripping through USD. + +.. automodule:: isaaclab.utils.warp.fabric + :members: + :show-inheritance: diff --git a/docs/source/api/lab_physx/isaaclab_physx.sim.views.rst b/docs/source/api/lab_physx/isaaclab_physx.sim.views.rst new file mode 100644 index 000000000000..4ee8e9e96745 --- /dev/null +++ b/docs/source/api/lab_physx/isaaclab_physx.sim.views.rst @@ -0,0 +1,17 @@ +isaaclab\_physx.sim.views +========================= + +.. automodule:: isaaclab_physx.sim.views + + .. rubric:: Classes + + .. autosummary:: + + FabricFrameView + +Fabric Frame View +----------------- + +.. autoclass:: FabricFrameView + :members: + :show-inheritance: diff --git a/scripts/benchmarks/benchmark_view_comparison.py b/scripts/benchmarks/benchmark_view_comparison.py index a637f687803e..80051f555d78 100644 --- a/scripts/benchmarks/benchmark_view_comparison.py +++ b/scripts/benchmarks/benchmark_view_comparison.py @@ -271,26 +271,71 @@ def _run_pose_benchmarks( positions: wp.array, orientations: wp.array, ): - """Shared benchmark loop for get/set world poses on any FrameView.""" + """Shared benchmark loop for get/set {world,local} poses on any FrameView.""" + + # FrameView getters now return ProxyArray; older callers worked with wp.array + # directly. Support both transparently. + def _as_wp(a): + return a.warp if hasattr(a, "warp") else a + + positions_wp = _as_wp(positions) + orientations_wp = _as_wp(orientations) + start_time = time.perf_counter() for _ in range(num_iterations): view.get_world_poses() timing_results["get_world_poses"] = (time.perf_counter() - start_time) / num_iterations - new_positions = wp.clone(positions) + new_positions = wp.clone(positions_wp) new_positions_t = wp.to_torch(new_positions) new_positions_t[:, 2] += 0.5 expected_positions = new_positions_t.clone() start_time = time.perf_counter() for _ in range(num_iterations): - view.set_world_poses(new_positions, orientations) + view.set_world_poses(new_positions, orientations_wp) timing_results["set_world_poses"] = (time.perf_counter() - start_time) / num_iterations + # Interleaved set→get on world poses — the realistic write/read pattern for + # downstream consumers (e.g. cameras updating their pose then immediately + # querying it). + start_time = time.perf_counter() + for _ in range(num_iterations): + view.set_world_poses(new_positions, orientations_wp) + view.get_world_poses() + timing_results["interleaved_world"] = (time.perf_counter() - start_time) / num_iterations + + # Local poses — Fabric-aware path on FabricFrameView, USD path otherwise. + if hasattr(view, "get_local_poses"): + start_time = time.perf_counter() + for _ in range(num_iterations): + view.get_local_poses() + timing_results["get_local_poses"] = (time.perf_counter() - start_time) / num_iterations + + if hasattr(view, "set_local_poses"): + local_pos, local_ori = view.get_local_poses() + local_pos_t = ( + local_pos.torch + if hasattr(local_pos, "torch") + else (wp.to_torch(local_pos) if isinstance(local_pos, wp.array) else local_pos) + ) + local_ori_t = ( + local_ori.torch + if hasattr(local_ori, "torch") + else (wp.to_torch(local_ori) if isinstance(local_ori, wp.array) else local_ori) + ) + new_local_pos = wp.from_torch(local_pos_t.clone().contiguous()) + new_local_ori = wp.from_torch(local_ori_t.clone().contiguous()) + + start_time = time.perf_counter() + for _ in range(num_iterations): + view.set_local_poses(translations=new_local_pos, orientations=new_local_ori) + timing_results["set_local_poses"] = (time.perf_counter() - start_time) / num_iterations + ret_pos, ret_quat = view.get_world_poses() - ret_pos_t = wp.to_torch(ret_pos) - ret_quat_t = wp.to_torch(ret_quat) - ori_t = wp.to_torch(orientations) + ret_pos_t = ret_pos.torch if hasattr(ret_pos, "torch") else wp.to_torch(ret_pos) + ret_quat_t = ret_quat.torch if hasattr(ret_quat, "torch") else wp.to_torch(ret_quat) + ori_t = wp.to_torch(orientations_wp) pos_ok = torch.allclose(ret_pos_t, expected_positions, atol=1e-4, rtol=0) quat_ok = torch.allclose(ret_quat_t, ori_t, atol=1e-4, rtol=0) @@ -327,6 +372,9 @@ def print_results(results_dict: dict[str, dict[str, float]], num_prims: int, num ("Initialization", "init"), ("Get World Poses", "get_world_poses"), ("Set World Poses", "set_world_poses"), + ("Interleaved Set->Get", "interleaved_world"), + ("Get Local Poses", "get_local_poses"), + ("Set Local Poses", "set_local_poses"), ] for op_name, op_key in operations: diff --git a/source/isaaclab/changelog.d/fix-fabric-local-matrix.rst b/source/isaaclab/changelog.d/fix-fabric-local-matrix.rst new file mode 100644 index 000000000000..a048992a809d --- /dev/null +++ b/source/isaaclab/changelog.d/fix-fabric-local-matrix.rst @@ -0,0 +1,24 @@ +Added +^^^^^ + +* Added :func:`~isaaclab.utils.warp.fabric.decompose_indexed_fabric_transforms` + and :func:`~isaaclab.utils.warp.fabric.compose_indexed_fabric_transforms` + Warp kernels. They mirror the existing + ``decompose_fabric_transformation_matrix_to_warp_arrays`` / + ``compose_fabric_transformation_matrix_from_warp_arrays`` kernels but + operate on :class:`wp.indexedfabricarray`, so the view-to-fabric mapping + is baked into the array and the kernel just dereferences + ``ifa[view_index]`` instead of taking a separate ``mapping`` argument. +* Added :func:`~isaaclab.utils.warp.fabric.update_indexed_local_matrix_from_world` + and :func:`~isaaclab.utils.warp.fabric.update_indexed_world_matrix_from_local` + Warp kernels that propagate ``local = world * inv(parent)`` and + ``world = local * parent`` directly on Fabric storage matrices (no + explicit transposes). Used by + :class:`~isaaclab_physx.sim.views.FabricFrameView` to keep child world and + local matrices consistent across writes without round-tripping through USD. +* Added :meth:`~isaaclab.sim.SimulationContext.get_service` and + :meth:`~isaaclab.sim.SimulationContext.set_service` — a typed singleton + service locator on :class:`~isaaclab.sim.SimulationContext`. Backend-specific + caches (e.g. Fabric hierarchy handles) register themselves here instead of + living as class-level globals. Services are automatically cleared on + :meth:`~isaaclab.sim.SimulationContext.clear_instance`. diff --git a/source/isaaclab/isaaclab/sim/simulation_context.py b/source/isaaclab/isaaclab/sim/simulation_context.py index 175961fcd383..211155b065b1 100644 --- a/source/isaaclab/isaaclab/sim/simulation_context.py +++ b/source/isaaclab/isaaclab/sim/simulation_context.py @@ -11,7 +11,7 @@ import traceback from collections.abc import Iterator from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar import toml import torch @@ -38,6 +38,8 @@ if TYPE_CHECKING: from isaaclab.cloner.clone_plan import ClonePlan +_T = TypeVar("_T") + from .simulation_cfg import SimulationCfg from .spawners import DomeLightCfg, GroundPlaneCfg @@ -214,6 +216,11 @@ def __init__(self, cfg: SimulationCfg | None = None): order=5, ) + # Singleton service registry — backend-specific caches (e.g. Fabric hierarchy) + # register themselves here, keyed by their class. All services are cleared when + # the SimulationContext is torn down via clear_instance(). + self._services: dict[type, object] = {} + type(self)._instance = self # Mark as valid singleton only after successful init def _apply_render_cfg_settings(self) -> None: @@ -852,6 +859,32 @@ def get_setting(self, name: str) -> Any: """Get a setting value.""" return self._settings_helper.get(name) + def get_service(self, cls: type[_T]) -> _T | None: + """Retrieve a registered singleton service by its class. + + Args: + cls: The service class used as key. + + Returns: + The registered instance, or ``None`` if not registered. + """ + return self._services.get(cls) # type: ignore[return-value] + + def set_service(self, cls: type[_T], instance: _T) -> None: + """Register a singleton service, keyed by its class. + + Overwrites any previously registered instance for the same class. + The service is automatically cleared when :meth:`clear_instance` is called. + + Args: + cls: The service class used as key. + instance: The service instance to register. + """ + old = self._services.get(cls) + if old is not None and old is not instance and hasattr(old, "close"): + old.close() + self._services[cls] = instance + @classmethod def clear_instance(cls) -> None: """Clean up resources and clear the singleton instance.""" @@ -865,6 +898,12 @@ def clear_instance(cls) -> None: viz.close() cls._instance._visualizers.clear() + # Close and drop all registered singleton services + for service in cls._instance._services.values(): + if hasattr(service, "close"): + service.close() + cls._instance._services.clear() + # Tear down the stage. We skip clear_stage() (prim-by-prim deletion) since # close_stage() + app shutdown destroy the entire stage at once. stage_utils.close_stage() diff --git a/source/isaaclab/isaaclab/utils/warp/fabric.py b/source/isaaclab/isaaclab/utils/warp/fabric.py index a48f773f4991..6f9963f290a1 100644 --- a/source/isaaclab/isaaclab/utils/warp/fabric.py +++ b/source/isaaclab/isaaclab/utils/warp/fabric.py @@ -15,15 +15,28 @@ import warp as wp +__all__ = [ + "arange_k", + "compose_fabric_transformation_matrix_from_warp_arrays", + "compose_indexed_fabric_transforms", + "decompose_fabric_transformation_matrix_to_warp_arrays", + "decompose_indexed_fabric_transforms", + "set_view_to_fabric_array", + "update_indexed_local_matrix_from_world", + "update_indexed_world_matrix_from_local", +] + if TYPE_CHECKING: FabricArrayUInt32 = Any FabricArrayMat44d = Any + IndexedFabricArrayMat44d = Any ArrayUInt32 = Any ArrayUInt32_1d = Any ArrayFloat32_2d = Any else: FabricArrayUInt32 = wp.fabricarray(dtype=wp.uint32) FabricArrayMat44d = wp.fabricarray(dtype=wp.mat44d) + IndexedFabricArrayMat44d = wp.indexedfabricarray(dtype=wp.mat44d) ArrayUInt32 = wp.array(ndim=1, dtype=wp.uint32) ArrayUInt32_1d = wp.array(dtype=wp.uint32) ArrayFloat32_2d = wp.array(ndim=2, dtype=wp.float32) @@ -130,29 +143,20 @@ def compose_fabric_transformation_matrix_from_warp_arrays( position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[fabric_index])) # update position (check if array has elements, not just if it exists) if array_positions.shape[0] > 0: - if broadcast_positions: - index = 0 - else: - index = i + index = wp.where(broadcast_positions, 0, i) position[0] = array_positions[index, 0] position[1] = array_positions[index, 1] position[2] = array_positions[index, 2] # update orientation (convert from wxyz to xyzw for Warp) if array_orientations.shape[0] > 0: - if broadcast_orientations: - index = 0 - else: - index = i + index = wp.where(broadcast_orientations, 0, i) rotation[0] = array_orientations[index, 0] # x rotation[1] = array_orientations[index, 1] # y rotation[2] = array_orientations[index, 2] # z rotation[3] = array_orientations[index, 3] # w # update scale if array_scales.shape[0] > 0: - if broadcast_scales: - index = 0 - else: - index = i + index = wp.where(broadcast_scales, 0, i) scale[0] = array_scales[index, 0] scale[1] = array_scales[index, 1] scale[2] = array_scales[index, 2] @@ -163,6 +167,167 @@ def compose_fabric_transformation_matrix_from_warp_arrays( ) +@wp.kernel(enable_backward=False) +def decompose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + indices: ArrayUInt32, +): + """Decompose indexed Fabric transformation matrices into position, orientation, and scale. + + Like :func:`decompose_fabric_transformation_matrix_to_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices. + array_positions: Output array for positions [m], shape (N, 3). + array_orientations: Output array for quaternions in xyzw format, shape (N, 4). + array_scales: Output array for scales, shape (N, 3). + indices: View indices to process (subset selection). + """ + output_index = wp.tid() + view_index = indices[output_index] + + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + array_positions[output_index, 0] = position[0] + array_positions[output_index, 1] = position[1] + array_positions[output_index, 2] = position[2] + if array_orientations.shape[0] > 0: + array_orientations[output_index, 0] = rotation[0] + array_orientations[output_index, 1] = rotation[1] + array_orientations[output_index, 2] = rotation[2] + array_orientations[output_index, 3] = rotation[3] + if array_scales.shape[0] > 0: + array_scales[output_index, 0] = scale[0] + array_scales[output_index, 1] = scale[1] + array_scales[output_index, 2] = scale[2] + + +@wp.kernel(enable_backward=False) +def compose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + broadcast_positions: bool, + broadcast_orientations: bool, + broadcast_scales: bool, + indices: ArrayUInt32, +): + """Compose indexed Fabric transformation matrices from position, orientation, and scale. + + Like :func:`compose_fabric_transformation_matrix_from_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices to update. + array_positions: Input array for positions [m], shape (N, 3). + array_orientations: Input array for quaternions in xyzw format, shape (N, 4). + array_scales: Input array for scales, shape (N, 3). + broadcast_positions: If True, use first position for all prims. + broadcast_orientations: If True, use first orientation for all prims. + broadcast_scales: If True, use first scale for all prims. + indices: View indices to process (subset selection). + """ + i = wp.tid() + view_index = indices[i] + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + index = wp.where(broadcast_positions, 0, i) + position[0] = array_positions[index, 0] + position[1] = array_positions[index, 1] + position[2] = array_positions[index, 2] + if array_orientations.shape[0] > 0: + index = wp.where(broadcast_orientations, 0, i) + rotation[0] = array_orientations[index, 0] + rotation[1] = array_orientations[index, 1] + rotation[2] = array_orientations[index, 2] + rotation[3] = array_orientations[index, 3] + if array_scales.shape[0] > 0: + index = wp.where(broadcast_scales, 0, i) + scale[0] = array_scales[index, 0] + scale[1] = array_scales[index, 1] + scale[2] = array_scales[index, 2] + + fabric_matrices[view_index] = wp.mat44d( # type: ignore[arg-type] + wp.transpose(wp.transform_compose(position, rotation, scale)) # type: ignore[arg-type] + ) + + +@wp.kernel(enable_backward=False) +def update_indexed_local_matrix_from_world( + child_world_matrices: IndexedFabricArrayMat44d, + parent_world_matrices: IndexedFabricArrayMat44d, + child_local_matrices: IndexedFabricArrayMat44d, + indices: ArrayUInt32, +): + """Recompute child localMatrix from (parent worldMatrix, child worldMatrix). + + Computes ``child_local = inv(parent_world) * child_world`` per prim and writes the + result back to the child's :data:`omni:fabric:localMatrix` so that subsequent + ``get_local_poses`` calls see consistent values after a world-pose write. + + All three indexed arrays are expected to be indexed by the same per-view indices + (i.e. ``view_to_child_fabric``, ``view_to_parent_fabric``, ``view_to_child_fabric``) + so the kernel only needs the view-side indices. + + Storage convention: Fabric matrices are stored as the transpose of the standard + column-major math convention. Math is ``local = inv(parent) * world``; under + the transpose identity ``(A * B)^T = B^T * A^T`` (and ``inv(A^T) = inv(A)^T``) + that is equivalent to storage-side ``local^T = world^T * inv(parent^T)``, so we + can compute it directly on the stored matrices without explicit transposes. + + Args: + child_world_matrices: Indexed fabric array of child world matrices (read). + parent_world_matrices: Indexed fabric array of parent world matrices (read). + child_local_matrices: Indexed fabric array of child local matrices (written). + indices: View indices to process. + """ + i = wp.tid() + view_index = indices[i] + child_world = wp.mat44f(child_world_matrices[view_index]) + parent_world = wp.mat44f(parent_world_matrices[view_index]) + child_local_matrices[view_index] = wp.mat44d(child_world * wp.inverse(parent_world)) # type: ignore[arg-type] + + +@wp.kernel(enable_backward=False) +def update_indexed_world_matrix_from_local( + child_local_matrices: IndexedFabricArrayMat44d, + parent_world_matrices: IndexedFabricArrayMat44d, + child_world_matrices: IndexedFabricArrayMat44d, + indices: ArrayUInt32, +): + """Recompute child worldMatrix from (parent worldMatrix, child localMatrix). + + Computes ``child_world = parent_world * child_local`` per prim and writes the + result back to the child's :data:`omni:fabric:worldMatrix`. Used after a + ``set_local_poses`` write so that subsequent ``get_world_poses`` calls see + consistent values. Mirror of :func:`update_indexed_local_matrix_from_world`. + + Args: + child_local_matrices: Indexed fabric array of child local matrices (read). + parent_world_matrices: Indexed fabric array of parent world matrices (read). + child_world_matrices: Indexed fabric array of child world matrices (written). + indices: View indices to process. + + Storage convention: same as :func:`update_indexed_local_matrix_from_world`. + Math is ``world = parent * local``; under the transpose identity that becomes + storage-side ``world^T = local^T * parent^T``, no explicit transposes needed. + """ + i = wp.tid() + view_index = indices[i] + child_local = wp.mat44f(child_local_matrices[view_index]) + parent_world = wp.mat44f(parent_world_matrices[view_index]) + child_world_matrices[view_index] = wp.mat44d(child_local * parent_world) # type: ignore[arg-type] + + @wp.func def _decompose_transformation_matrix(m: Any): # -> tuple[wp.vec3f, wp.quatf, wp.vec3f] """Decompose a 4x4 transformation matrix into position, orientation, and scale. diff --git a/source/isaaclab_physx/changelog.d/fix-fabric-local-matrix.rst b/source/isaaclab_physx/changelog.d/fix-fabric-local-matrix.rst new file mode 100644 index 000000000000..a99a5fe19a5c --- /dev/null +++ b/source/isaaclab_physx/changelog.d/fix-fabric-local-matrix.rst @@ -0,0 +1,28 @@ +Fixed +^^^^^ + +* Fixed :meth:`~isaaclab_physx.sim.views.FabricFrameView.get_local_poses` + returning stale USD values after Fabric world-pose writes. Local poses + are now read directly from Fabric's ``omni:fabric:localMatrix`` via + :class:`wp.indexedfabricarray`, and are kept consistent with worldMatrix + through Warp kernels that propagate either direction on writes. + +Changed +^^^^^^^ + +* Reworked :class:`~isaaclab_physx.sim.views.FabricFrameView` to use three + persistent ``PrimSelection`` instances (one per access mode), path-based + view → fabric index mapping (no custom prim attributes), and Warp kernels + that operate on :class:`wp.indexedfabricarray` so the kernels just index + ``ifa[view_index]`` instead of taking a separate mapping array. +* Moved the ``IFabricHierarchy`` handle cache out of ``FabricFrameView`` (class-level + global) into a new :class:`~isaaclab_physx.sim.fabric_stage_cache.FabricStageCache`, + registered as a service on :class:`~isaaclab.sim.SimulationContext`. The cache is + automatically cleared on stage teardown. +* :meth:`~isaaclab_physx.sim.views.FabricFrameView.set_local_poses` now + writes ``omni:fabric:localMatrix`` directly through Fabric. The next + ``get_world_poses`` runs a Warp kernel that recomputes + ``child_world = parent_world * child_local``. Symmetrically, + ``set_world_poses`` runs a kernel that recomputes + ``child_local = inv(parent_world) * child_world`` so subsequent + ``get_local_poses`` calls return consistent values. diff --git a/source/isaaclab_physx/isaaclab_physx/sim/fabric_stage_cache.py b/source/isaaclab_physx/isaaclab_physx/sim/fabric_stage_cache.py new file mode 100644 index 000000000000..22386b858b5c --- /dev/null +++ b/source/isaaclab_physx/isaaclab_physx/sim/fabric_stage_cache.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Fabric stage and hierarchy cache, registered as a service on SimulationContext.""" + +from __future__ import annotations + +from pxr import UsdUtils + + +class FabricStageCache: + """Caches the usdrt stage attachment and IFabricHierarchy handles. + + Registered as a singleton service on :class:`~isaaclab.sim.SimulationContext` via + ``set_service(FabricStageCache, ...)``. Multiple + :class:`~isaaclab_physx.sim.views.FabricFrameView` instances share a single + hierarchy handle per Fabric attachment. + + The hierarchy cache is keyed by ``fabric_id_int`` (the stable ``.id`` integer from + ``FabricId``). Currently Isaac Lab always has exactly one Fabric attachment per + stage, so this dict will hold at most one entry. A dict is used rather than a plain + attribute so the design naturally extends to multi-Fabric scenarios (e.g. multi-GPU + support, where each GPU gets its own Fabric attachment) without an API change. + """ + + def __init__(self, usd_stage) -> None: + import usdrt # noqa: PLC0415 + + stage_id = UsdUtils.StageCache.Get().GetId(usd_stage).ToLongInt() + self._stage = usdrt.Usd.Stage.Attach(stage_id) + self._stage.SynchronizeToFabric() + self._hierarchy_cache: dict[int, object] = {} + + @property + def stage(self): + """The usdrt stage (already attached and synchronized).""" + return self._stage + + def close(self) -> None: + """Release cached handles. Called by SimulationContext on teardown.""" + self._hierarchy_cache.clear() + self._stage = None + + def get_hierarchy(self): + """Return the IFabricHierarchy handle for the current Fabric attachment. + + Creates and caches the handle on first call. Change-tracking is enabled + for both local and world xforms. + + Returns: + A tuple of ``(hierarchy_handle, fabric_id_int)``. + """ + import usdrt # noqa: PLC0415 + + fabric_id = self._stage.GetFabricId() + fabric_id_int = fabric_id.id + + if fabric_id_int not in self._hierarchy_cache: + hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( + fabric_id, self._stage.GetStageIdAsStageId() + ) + hierarchy.track_local_xform_changes(True) + hierarchy.track_world_xform_changes(True) + self._hierarchy_cache[fabric_id_int] = hierarchy + + return self._hierarchy_cache[fabric_id_int], fabric_id_int diff --git a/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py b/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py index 1bcff86d57ac..c4274f1f69fb 100644 --- a/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py +++ b/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py @@ -12,7 +12,7 @@ import torch import warp as wp -from pxr import Usd +from pxr import Gf, Usd, UsdGeom import isaaclab.sim as sim_utils from isaaclab.app.settings_manager import SettingsManager @@ -49,23 +49,43 @@ def _to_float32_2d(a: wp.array | torch.Tensor) -> wp.array | torch.Tensor: class FabricFrameView(BaseFrameView): """FrameView with Fabric GPU acceleration for the PhysX backend. - Uses composition: holds a :class:`UsdFrameView` internally for USD - fallback and non-accelerated operations (local poses, visibility, scales - when Fabric is disabled). - - When Fabric is enabled, world-pose and scale operations use Warp kernels - operating on ``omni:fabric:worldMatrix``. All other operations delegate - to the internal USD view. - - After every Fabric write (``set_world_poses``, ``set_scales``), - :meth:`PrepareForReuse` is called on the ``PrimSelection`` to notify - the FSD renderer that Fabric data has changed and to detect topology - changes that require rebuilding internal mappings. Read operations - do not call PrepareForReuse to avoid unnecessary renderer invalidation. - - Pose getters return :class:`~isaaclab.utils.warp.ProxyArray`. Setters accept ``wp.array``. + World-pose, local-pose, and scale operations run on the GPU via Warp + kernels that read and write ``omni:fabric:worldMatrix`` and + ``omni:fabric:localMatrix`` directly. Typical speedup vs. the + :class:`~isaaclab.sim.views.UsdFrameView` baseline at 1024 prims is + 150-260× per call (see ``scripts/benchmarks/benchmark_view_comparison.py``). + + When Fabric is unavailable — ``/physics/fabricEnabled`` is false or the + device is unsupported — the view transparently falls back to + :class:`~isaaclab.sim.views.UsdFrameView` for all pose and scale + operations. The ``count``, ``prims``, ``prim_paths`` properties and the + ``get_visibility`` / ``set_visibility`` methods always delegate to + :class:`~isaaclab.sim.views.UsdFrameView`; Fabric has no equivalent fast + path for those. + + Behavior: + + * **No write-back to USD.** Fabric writes update only + ``omni:fabric:worldMatrix`` / ``omni:fabric:localMatrix``; the prim's + USD ``xformOp:*`` attributes are unchanged. Downstream consumers that + read the prim's USD attributes after a Fabric write will see stale + values until the next USD-side sync. + * **World ↔ local consistency.** After ``set_world_poses`` (or + ``set_scales``) the local matrix is updated so that subsequent + ``get_local_poses`` is consistent; after ``set_local_poses`` the world + matrix is recomputed on the next world read. Both directions stay in + sync without round-tripping through USD. + * **Topology-adaptive.** Fabric topology changes are detected on each + access; the view rebuilds its internal mapping automatically and no + manual refresh is required. Steady-state overhead is negligible. + + Pose getters return :class:`~isaaclab.utils.warp.ProxyArray`; setters + accept :class:`wp.array`. """ + _WORLD_MATRIX_NAME = "omni:fabric:worldMatrix" + _LOCAL_MATRIX_NAME = "omni:fabric:localMatrix" + def __init__( self, prim_path: str, @@ -101,14 +121,44 @@ def __init__( ) self._use_fabric = False + # Fabric state — all populated lazily in :meth:`_initialize_fabric`. self._fabric_initialized = False - self._fabric_usd_sync_done = False - self._fabric_selection = None - self._fabric_to_view: wp.array | None = None - self._view_to_fabric: wp.array | None = None - self._default_view_indices: wp.array | None = None + self._stage = None self._fabric_hierarchy = None - self._view_index_attr = f"isaaclab:view_index:{abs(hash(self))}" + # Set by ``set_local_poses``; cleared by ``_sync_world_from_local_if_dirty``. + # Per-view (not per-stage) so concurrent views on the same stage don't clear + # each other's flag. + self._world_dirty: bool = False + + # Selections. + self._trans_sel_ro = None + self._world_sel_rw = None + self._local_sel_rw = None + + # Index arrays (view-side indices and view→fabric mappings). Each selection's + # ``GetPaths()`` ordering is independent, so the view→fabric mapping is cached + # per selection rather than shared — sharing would silently corrupt indexed + # arrays whose selection didn't fire ``PrepareForReuse`` on the same frame. + self._view_indices: wp.array | None = None + self._trans_ro_fabric_indices: wp.array | None = None + self._world_rw_fabric_indices: wp.array | None = None + self._local_rw_fabric_indices: wp.array | None = None + self._parent_fabric_indices: wp.array | None = None + + # Indexed fabric arrays. + self._world_ifa_ro = None + self._local_ifa_ro = None + self._world_ifa_rw = None + self._local_ifa_rw = None + self._parent_world_ifa_ro = None + + # Sentinel passed to ``compose_indexed_fabric_transforms`` / + # ``decompose_indexed_fabric_transforms`` for slots the caller does not want + # written or read. The kernels gate every per-row access on + # ``shape[0] > 0``, so a ``(0, 0)`` array is enough — the inner dim is never + # indexed. One shared instance covers all "unused" slots regardless of + # whether they would have held positions, quaternions, or scales. + self._fabric_empty_2d_array_sentinel: wp.array | None = None # ------------------------------------------------------------------ # Delegated properties @@ -153,39 +203,33 @@ def set_world_poses(self, positions=None, orientations=None, indices=None): if not self._fabric_initialized: self._initialize_fabric() - self._prepare_for_reuse() + # If a prior set_local_poses left worldMatrix stale, propagate local → world first. + self._sync_world_from_local_if_dirty() indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - dummy = wp.zeros((0, 3), dtype=wp.float32, device=self._device) - positions_wp = _to_float32_2d(positions) if positions is not None else dummy - orientations_wp = ( - _to_float32_2d(orientations) - if orientations is not None - else wp.zeros((0, 4), dtype=wp.float32, device=self._device) - ) + positions_wp = self._to_float32_2d_or_empty(positions) + orientations_wp = self._to_float32_2d_or_empty(orientations) wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], inputs=[ - self._fabric_world_matrices, + self._get_world_rw_array(), positions_wp, orientations_wp, - dummy, + self._fabric_empty_2d_array_sentinel, False, False, False, indices_wp, - self._view_to_fabric, ], - device=self._fabric_device, + device=self._device, ) wp.synchronize() - self._fabric_hierarchy.update_world_xforms() - self._fabric_usd_sync_done = True + # World was just written — recompute child localMatrix from parent worldMatrix + # so the next get_local_poses returns consistent values. + self._sync_local_from_world(indices_wp) def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: if not self._use_fabric: @@ -193,8 +237,9 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, if not self._fabric_initialized: self._initialize_fabric() - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() + + # If a prior set_local_poses left worldMatrix stale, propagate local → world first. + self._sync_world_from_local_if_dirty() indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -208,17 +253,16 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, orientations_wp = wp.zeros((count, 4), dtype=wp.float32, device=self._device) wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + kernel=fabric_utils.decompose_indexed_fabric_transforms, dim=count, inputs=[ - self._fabric_world_matrices, + self._get_world_ro_array(), positions_wp, orientations_wp, - self._fabric_dummy_buffer, + self._fabric_empty_2d_array_sentinel, indices_wp, - self._view_to_fabric, ], - device=self._fabric_device, + device=self._device, ) if use_cached: @@ -227,17 +271,79 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, return ProxyArray(positions_wp), ProxyArray(orientations_wp) # ------------------------------------------------------------------ - # Local poses — USD fallback (Fabric only accelerates world poses) + # Local poses # ------------------------------------------------------------------ def set_local_poses(self, translations=None, orientations=None, indices=None): - self._usd_view.set_local_poses(translations, orientations, indices) + if not self._use_fabric: + self._usd_view.set_local_poses(translations, orientations, indices) + return + + if not self._fabric_initialized: + self._initialize_fabric() + + indices_wp = self._resolve_indices_wp(indices) + translations_wp = self._to_float32_2d_or_empty(translations) + orientations_wp = self._to_float32_2d_or_empty(orientations) + + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[ + self._get_local_rw_array(), + translations_wp, + orientations_wp, + self._fabric_empty_2d_array_sentinel, + False, + False, + False, + indices_wp, + ], + device=self._device, + ) + wp.synchronize() + + # Mark this view's worlds stale so the next world read recomputes them. + self._world_dirty = True def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: - return self._usd_view.get_local_poses(indices) + if not self._use_fabric: + return self._usd_view.get_local_poses(indices) + + if not self._fabric_initialized: + self._initialize_fabric() + + indices_wp = self._resolve_indices_wp(indices) + count = indices_wp.shape[0] + + use_cached = indices is None or indices == slice(None) + if use_cached: + translations_wp = self._fabric_local_translations_buf + orientations_wp = self._fabric_local_orientations_buf + else: + translations_wp = wp.zeros((count, 3), dtype=wp.float32, device=self._device) + orientations_wp = wp.zeros((count, 4), dtype=wp.float32, device=self._device) + + wp.launch( + kernel=fabric_utils.decompose_indexed_fabric_transforms, + dim=count, + inputs=[ + self._get_local_ro_array(), + translations_wp, + orientations_wp, + self._fabric_empty_2d_array_sentinel, + indices_wp, + ], + device=self._device, + ) + + if use_cached: + wp.synchronize() + return self._fabric_local_translations_ta, self._fabric_local_orientations_ta + return ProxyArray(translations_wp), ProxyArray(orientations_wp) # ------------------------------------------------------------------ - # Scales — Fabric-accelerated or USD fallback + # Scales # ------------------------------------------------------------------ def set_scales(self, scales, indices=None): @@ -248,35 +354,32 @@ def set_scales(self, scales, indices=None): if not self._fabric_initialized: self._initialize_fabric() - self._prepare_for_reuse() + # Sync world matrices first if local writes are pending. + self._sync_world_from_local_if_dirty() indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - dummy3 = wp.zeros((0, 3), dtype=wp.float32, device=self._device) - dummy4 = wp.zeros((0, 4), dtype=wp.float32, device=self._device) - scales_wp = _to_float32_2d(scales) + scales_wp = self._to_float32_2d_or_empty(scales) wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], inputs=[ - self._fabric_world_matrices, - dummy3, - dummy4, + self._get_world_rw_array(), + self._fabric_empty_2d_array_sentinel, + self._fabric_empty_2d_array_sentinel, scales_wp, False, False, False, indices_wp, - self._view_to_fabric, ], - device=self._fabric_device, + device=self._device, ) wp.synchronize() - self._fabric_hierarchy.update_world_xforms() - self._fabric_usd_sync_done = True + # World was just written — recompute child localMatrix from parent worldMatrix + # so the next get_local_poses returns the new scale rather than the stale one. + self._sync_local_from_world(indices_wp) def get_scales(self, indices=None): if not self._use_fabric: @@ -284,8 +387,9 @@ def get_scales(self, indices=None): if not self._fabric_initialized: self._initialize_fabric() - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() + + # Sync world matrices first if local writes are pending. + self._sync_world_from_local_if_dirty() indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -297,17 +401,16 @@ def get_scales(self, indices=None): scales_wp = wp.zeros((count, 3), dtype=wp.float32, device=self._device) wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + kernel=fabric_utils.decompose_indexed_fabric_transforms, dim=count, inputs=[ - self._fabric_world_matrices, - self._fabric_dummy_buffer, - self._fabric_dummy_buffer, + self._get_world_ro_array(), + self._fabric_empty_2d_array_sentinel, + self._fabric_empty_2d_array_sentinel, scales_wp, indices_wp, - self._view_to_fabric, ], - device=self._fabric_device, + device=self._device, ) if use_cached: @@ -315,153 +418,402 @@ def get_scales(self, indices=None): return scales_wp # ------------------------------------------------------------------ - # Internal — PrepareForReuse (renderer notification + topology tracking) + # Internal — sync helpers # ------------------------------------------------------------------ - def _prepare_for_reuse(self) -> None: - """Call PrepareForReuse on the PrimSelection to notify the renderer. + def _to_float32_2d_or_empty(self, data): + return self._fabric_empty_2d_array_sentinel if data is None else _to_float32_2d(data) - PrepareForReuse serves two purposes: + def _sync_world_from_local_if_dirty(self) -> None: + """If a prior local write left world matrices stale, recompute them on the fly. - 1. **Renderer notification**: Tells FSD/Storm that Fabric data has - been (or will be) modified, so the next rendered frame reflects - the updated transforms. - 2. **Topology change detection**: Returns True when Fabric's - internal memory layout changed (e.g., prims added/removed). - In that case, view-to-fabric index mappings and fabricarrays - must be rebuilt. + We deliberately do NOT call ``IFabricHierarchy.update_world_xforms()`` — + in practice that re-reads USD's authored xformOps and overwrites the Fabric + local+world matrices we just authored. Instead we fire a Warp kernel that + does ``child_world = parent_world * child_local`` per child, leaving the + Fabric-side localMatrix untouched. """ - if self._fabric_selection is None: + if not self._world_dirty: return + # Refresh trans_sel_ro once, then read _local_ifa_ro and _parent_world_ifa_ro + # directly to avoid calling PrepareForReuse twice on the same selection. + if self._trans_sel_ro.PrepareForReuse() or self._parent_world_ifa_ro is None: + self._rebuild_trans_ro_arrays() + wp.launch( + kernel=fabric_utils.update_indexed_world_matrix_from_local, + dim=self.count, + inputs=[ + self._local_ifa_ro, + self._parent_world_ifa_ro, + self._get_world_rw_array(), + self._view_indices, + ], + device=self._device, + ) + wp.synchronize() + self._world_dirty = False - topology_changed = self._fabric_selection.PrepareForReuse() - if topology_changed: - logger.info("Fabric topology changed — rebuilding view-to-fabric index mapping.") - self._rebuild_fabric_arrays() - - def _rebuild_fabric_arrays(self) -> None: - """Rebuild fabricarray and view↔fabric mappings after a topology change. + def _sync_local_from_world(self, indices_wp: wp.array) -> None: + """Recompute child ``localMatrix`` from (parent worldMatrix, child worldMatrix). - Note: Only index mappings and fabricarrays are rebuilt. Position/orientation/scale - buffers are *not* resized because ``self.count`` is derived from the USD prim-path - pattern (via ``_usd_view.count``) and does not change when Fabric rearranges its - internal memory layout. The assertion below guards this invariant. + Called after ``set_world_poses`` so that subsequent ``get_local_poses`` returns + values consistent with the just-written world poses. Fabric Hierarchy does + not provide a built-in world → local sync, so we do it via a Warp kernel + using the parent indexed fabric array. """ - assert self.count == self._default_view_indices.shape[0], ( - f"Prim count changed ({self.count} vs {self._default_view_indices.shape[0]}). " - "Fabric topology change added/removed tracked prims — full re-initialization required." - ) - self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32, device=self._fabric_device) - self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr) - + # Refresh trans_sel_ro once; _world_ifa_ro and _parent_world_ifa_ro share it. + if self._trans_sel_ro.PrepareForReuse() or self._parent_world_ifa_ro is None: + self._rebuild_trans_ro_arrays() wp.launch( - kernel=fabric_utils.set_view_to_fabric_array, - dim=self._fabric_to_view.shape[0], - inputs=[self._fabric_to_view, self._view_to_fabric], - device=self._fabric_device, + kernel=fabric_utils.update_indexed_local_matrix_from_world, + dim=indices_wp.shape[0], + inputs=[ + self._world_ifa_ro, + self._parent_world_ifa_ro, + self._get_local_rw_array(), + indices_wp, + ], + device=self._device, ) wp.synchronize() - self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") + # ------------------------------------------------------------------ + # Internal — selection accessors with on-demand index rebuild + # ------------------------------------------------------------------ + + def _get_world_ro_array(self): + if self._trans_sel_ro.PrepareForReuse(): + self._rebuild_trans_ro_arrays() + return self._world_ifa_ro + + def _get_local_ro_array(self): + if self._trans_sel_ro.PrepareForReuse(): + self._rebuild_trans_ro_arrays() + return self._local_ifa_ro + + def _get_world_rw_array(self): + if self._world_sel_rw.PrepareForReuse(): + self._world_rw_fabric_indices = self._compute_fabric_indices(self._world_sel_rw) + self._world_ifa_rw = self._build_indexed_array( + self._world_sel_rw, self._WORLD_MATRIX_NAME, self._world_rw_fabric_indices + ) + return self._world_ifa_rw + + def _get_local_rw_array(self): + if self._local_sel_rw.PrepareForReuse(): + self._local_rw_fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + self._local_ifa_rw = self._build_indexed_array( + self._local_sel_rw, self._LOCAL_MATRIX_NAME, self._local_rw_fabric_indices + ) + return self._local_ifa_rw + + def _get_parent_world_ro_array(self): + # Built and refreshed alongside the trans_ro selection (parents share that selection). + if self._parent_world_ifa_ro is None or self._trans_sel_ro.PrepareForReuse(): + self._rebuild_trans_ro_arrays() + return self._parent_world_ifa_ro + + def _rebuild_trans_ro_arrays(self) -> None: + """Rebuild the trans_ro indices and the three indexed arrays that depend on them. + + ``_world_ifa_ro``, ``_local_ifa_ro`` and ``_parent_world_ifa_ro`` are all + keyed off the ``trans_sel_ro`` path ordering, so they are refreshed together. + """ + self._trans_ro_fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) + self._world_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._WORLD_MATRIX_NAME, self._trans_ro_fabric_indices + ) + self._local_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._LOCAL_MATRIX_NAME, self._trans_ro_fabric_indices + ) + self._parent_world_ifa_ro = self._build_parent_indexed_array(self._trans_sel_ro) + + # ------------------------------------------------------------------ + # Internal — index computation + # ------------------------------------------------------------------ + + def _compute_fabric_indices(self, selection) -> wp.array: + fabric_paths = selection.GetPaths() + path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} + indices: list[int] = [] + for prim_path in self.prim_paths: + fabric_idx = path_to_fabric_idx.get(prim_path) + if fabric_idx is None: + raise RuntimeError( + f"Prim '{prim_path}' not found in Fabric selection. Ensure the hierarchy has been populated." + ) + indices.append(fabric_idx) + return wp.array(indices, dtype=wp.int32, device=self._device) + + def _compute_parent_fabric_indices(self, selection) -> wp.array: + """For each child in this view, look up the parent prim's fabric index.""" + fabric_paths = selection.GetPaths() + path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} + indices: list[int] = [] + for prim_path in self.prim_paths: + parent_path = prim_path.rsplit("/", 1)[0] + if parent_path == "": + raise RuntimeError( + f"Child prim '{prim_path}' is at stage root and has no parent prim. " + "FabricFrameView requires every prim to have a non-pseudoroot parent " + "with Fabric world+local matrices." + ) + fabric_idx = path_to_fabric_idx.get(parent_path) + if fabric_idx is None: + raise RuntimeError( + f"Parent prim '{parent_path}' (for child '{prim_path}') not found in Fabric selection. " + "Ensure parents have Fabric world+local matrices populated." + ) + indices.append(fabric_idx) + return wp.array(indices, dtype=wp.int32, device=self._device) + + def _build_indexed_array(self, selection, attribute_name: str, fabric_indices: wp.array) -> wp.indexedfabricarray: + fa = wp.fabricarray(selection, attribute_name) + return wp.indexedfabricarray(fa=fa, indices=fabric_indices) + + def _build_parent_indexed_array(self, selection) -> wp.indexedfabricarray: + self._parent_fabric_indices = self._compute_parent_fabric_indices(selection) + fa = wp.fabricarray(selection, self._WORLD_MATRIX_NAME) + return wp.indexedfabricarray(fa=fa, indices=self._parent_fabric_indices) + + def _resolve_indices_wp(self, indices: wp.array | None) -> wp.array: + """Resolve view indices as a Warp uint32 array.""" + if indices is None or indices == slice(None): + if self._view_indices is None: + raise RuntimeError("Fabric view indices are not initialized.") + return self._view_indices + if indices.dtype != wp.uint32: + return wp.array(indices.numpy().astype("uint32"), dtype=wp.uint32, device=self._device) + return indices # ------------------------------------------------------------------ # Internal — Fabric initialization # ------------------------------------------------------------------ def _initialize_fabric(self) -> None: - """Initialize Fabric batch infrastructure for GPU-accelerated pose queries.""" + """One-time Fabric setup: hierarchy handle, attribute population, selections, indexed arrays.""" import usdrt # noqa: PLC0415 from usdrt import Rt # noqa: PLC0415 - stage_id = sim_utils.get_current_stage_id() - fabric_stage = usdrt.Usd.Stage.Attach(stage_id) + from isaaclab.sim.simulation_context import SimulationContext # noqa: PLC0415 - for i in range(self.count): - rt_prim = fabric_stage.GetPrimAtPath(self.prim_paths[i]) - rt_xformable = Rt.Xformable(rt_prim) + from isaaclab_physx.sim.fabric_stage_cache import FabricStageCache # noqa: PLC0415 - has_attr = ( - rt_xformable.HasFabricHierarchyWorldMatrixAttr() - if hasattr(rt_xformable, "HasFabricHierarchyWorldMatrixAttr") - else False + sim_context = SimulationContext.instance() + if sim_context is None: + raise RuntimeError( + "FabricFrameView requires an active SimulationContext. " + "Create a SimulationContext before instantiating FabricFrameView." ) - if not has_attr: - rt_xformable.CreateFabricHierarchyWorldMatrixAttr() - - rt_xformable.SetWorldXformFromUsd() - rt_prim.CreateAttribute(self._view_index_attr, usdrt.Sdf.ValueTypeNames.UInt, custom=True) - rt_prim.GetAttribute(self._view_index_attr).Set(i) - - self._fabric_hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( - fabric_stage.GetFabricId(), fabric_stage.GetStageIdAsStageId() + # Get or create the FabricStageCache service. + cache = sim_context.get_service(FabricStageCache) + if cache is None: + cache = FabricStageCache(sim_context.stage) + sim_context.set_service(FabricStageCache, cache) + + self._stage = cache.stage + self._fabric_hierarchy, self._fabric_id = cache.get_hierarchy() + + # Ensure each child prim AND its parent have BOTH Fabric world and local matrix + # attributes. Our ``trans_ro`` selection requires both, so prims missing either + # would silently be excluded. ``Create*Attr`` calls are idempotent. + # + # ``SetWorldXformFromUsd`` writes Fabric's worldMatrix from USD's accumulated + # local-to-world transform (so it picks up the parent chain). + # ``SetLocalXformFromUsd`` writes Fabric's localMatrix from USD's authored + # xformOps on this prim only. Calling both gives Fabric a consistent + # (worldMatrix, localMatrix) pair for each prim before we touch the hierarchy. + seen_paths: set[str] = set() + for child_path in self.prim_paths: + for path in (child_path, child_path.rsplit("/", 1)[0]): + if path in seen_paths: + continue + seen_paths.add(path) + rt_prim = self._stage.GetPrimAtPath(path) + if not rt_prim.IsValid(): + continue + rt_xformable = Rt.Xformable(rt_prim) + rt_xformable.CreateFabricHierarchyWorldMatrixAttr() + rt_xformable.CreateFabricHierarchyLocalMatrixAttr() + rt_xformable.SetLocalXformFromUsd() + rt_xformable.SetWorldXformFromUsd() + + # Three persistent selections — read both, write world, write local. + matrix = usdrt.Sdf.ValueTypeNames.Matrix4d + ro = usdrt.Usd.Access.Read + rw = usdrt.Usd.Access.ReadWrite + wm_ro = (matrix, self._WORLD_MATRIX_NAME, ro) + lm_ro = (matrix, self._LOCAL_MATRIX_NAME, ro) + wm_rw = (matrix, self._WORLD_MATRIX_NAME, rw) + lm_rw = (matrix, self._LOCAL_MATRIX_NAME, rw) + self._trans_sel_ro = self._stage.SelectPrims(require_attrs=[wm_ro, lm_ro], device=self._device, want_paths=True) + self._world_sel_rw = self._stage.SelectPrims(require_attrs=[wm_rw, lm_ro], device=self._device, want_paths=True) + self._local_sel_rw = self._stage.SelectPrims(require_attrs=[wm_ro, lm_rw], device=self._device, want_paths=True) + + # Build the view-side indices array (just [0..count-1]) and a per-selection + # view→fabric mapping (selections do not guarantee a shared path ordering). + self._view_indices = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) + self._trans_ro_fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) + self._world_rw_fabric_indices = self._compute_fabric_indices(self._world_sel_rw) + self._local_rw_fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + + # Indexed fabric arrays per (selection × attribute). + self._world_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._WORLD_MATRIX_NAME, self._trans_ro_fabric_indices ) - self._fabric_hierarchy.update_world_xforms() - - self._default_view_indices = wp.zeros((self.count,), dtype=wp.uint32, device=self._device) - wp.launch( - kernel=fabric_utils.arange_k, dim=self.count, inputs=[self._default_view_indices], device=self._device + self._local_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._LOCAL_MATRIX_NAME, self._trans_ro_fabric_indices ) - wp.synchronize() - - # The constructor should have taken care of this, but double check here to avoid regressions - assert self._device in _fabric_supported_devices - - self._fabric_selection = fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, "omni:fabric:worldMatrix", usdrt.Usd.Access.ReadWrite), - ], - device=self._device, + self._world_ifa_rw = self._build_indexed_array( + self._world_sel_rw, self._WORLD_MATRIX_NAME, self._world_rw_fabric_indices ) - - self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32, device=self._device) - self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr) - - wp.launch( - kernel=fabric_utils.set_view_to_fabric_array, - dim=self._fabric_to_view.shape[0], - inputs=[self._fabric_to_view, self._view_to_fabric], - device=self._device, + self._local_ifa_rw = self._build_indexed_array( + self._local_sel_rw, self._LOCAL_MATRIX_NAME, self._local_rw_fabric_indices ) - wp.synchronize() + self._parent_world_ifa_ro = self._build_parent_indexed_array(self._trans_sel_ro) + # Pre-allocated reusable output buffers (world + local + scales). self._fabric_positions_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) self._fabric_orientations_buf = wp.zeros((self.count, 4), dtype=wp.float32, device=self._device) + self._fabric_scales_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) + self._fabric_local_translations_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) + self._fabric_local_orientations_buf = wp.zeros((self.count, 4), dtype=wp.float32, device=self._device) + self._fabric_empty_2d_array_sentinel = wp.zeros((0, 0), dtype=wp.float32, device=self._device) + self._fabric_positions_ta = ProxyArray(self._fabric_positions_buf) self._fabric_orientations_ta = ProxyArray(self._fabric_orientations_buf) - self._fabric_scales_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) - self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32, device=self._device) - self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") - self._fabric_stage = fabric_stage - self._fabric_device = self._device + self._fabric_local_translations_ta = ProxyArray(self._fabric_local_translations_buf) + self._fabric_local_orientations_ta = ProxyArray(self._fabric_local_orientations_buf) self._fabric_initialized = True - self._fabric_usd_sync_done = False - def _sync_fabric_from_usd_once(self) -> None: - """Sync Fabric world matrices from USD once, on the first read. + # Seed Fabric matrices from USD authoritatively. ``SetWorldXformFromUsd`` / + # ``SetLocalXformFromUsd`` are no-ops on freshly authored stages that haven't + # been rendered yet; we instead read through the USD view (children) and + # ``UsdGeom.XformCache`` (parents) and write via the same compose kernel that + # ``set_world_poses`` uses. + self._sync_fabric_from_usd_initial() - ``set_world_poses`` and ``set_scales`` each set ``_fabric_usd_sync_done`` - themselves, so no explicit flag assignment is needed here. - """ - if not self._fabric_initialized: - self._initialize_fabric() + def _sync_fabric_from_usd_initial(self) -> None: + """Populate Fabric world+local matrices for children and parents from USD. - positions_usd_ta, orientations_usd_ta = self._usd_view.get_world_poses() - positions_usd = positions_usd_ta.warp - orientations_usd = orientations_usd_ta.warp - scales_usd = self._usd_view.get_scales() + Performed once during ``_initialize_fabric``. Without this step Fabric's + matrices are identity for stages that haven't been rendered yet, and our + getters (which read from Fabric) would return wrong values. + """ + # --- Children --- + pos_ta, ori_ta = self._usd_view.get_world_poses() + scales_obj = self._usd_view.get_scales() + scales_wp = ( + scales_obj.warp + if hasattr(scales_obj, "warp") + else scales_obj + if isinstance(scales_obj, wp.array) + else self._fabric_empty_2d_array_sentinel + ) + local_pos_ta, local_ori_ta = self._usd_view.get_local_poses() + # Compose into child worldMatrix. + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=self.count, + inputs=[ + self._world_ifa_rw, + _to_float32_2d(pos_ta.warp), + _to_float32_2d(ori_ta.warp), + _to_float32_2d(scales_wp), + False, + False, + False, + self._view_indices, + ], + device=self._device, + ) + # Compose into child localMatrix. Pass the locally-authored scale so + # that a subsequent ``_sync_world_from_local_if_dirty`` produces the + # right world-space scale (``world = parent_world * local`` carries + # ``local``'s scale through the multiply). + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=self.count, + inputs=[ + self._local_ifa_rw, + _to_float32_2d(local_pos_ta.warp), + _to_float32_2d(local_ori_ta.warp), + _to_float32_2d(scales_wp), + False, + False, + False, + self._view_indices, + ], + device=self._device, + ) - self.set_world_poses(positions_usd, orientations_usd) - self.set_scales(scales_usd) + # --- Parents (one entry per unique parent path) --- + unique_parent_paths = list(dict.fromkeys(p.rsplit("/", 1)[0] for p in self.prim_paths)) + if unique_parent_paths: + usd_stage = sim_utils.get_current_stage() + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + world_pos_rows: list[list[float]] = [] + world_ori_rows: list[list[float]] = [] + world_scale_rows: list[list[float]] = [] + decomposer = Gf.Transform() + for path in unique_parent_paths: + prim = usd_stage.GetPrimAtPath(path) + tf = xform_cache.GetLocalToWorldTransform(prim) + # Extract scale before ``Orthonormalize`` strips it from the rows. + decomposer.SetMatrix(tf) + s = decomposer.GetScale() + tf.Orthonormalize() + t = tf.ExtractTranslation() + q = tf.ExtractRotationQuat() + img, real = q.GetImaginary(), q.GetReal() + world_pos_rows.append([float(t[0]), float(t[1]), float(t[2])]) + world_ori_rows.append([float(img[0]), float(img[1]), float(img[2]), float(real)]) + world_scale_rows.append([float(s[0]), float(s[1]), float(s[2])]) + parent_view_indices = wp.array(list(range(len(unique_parent_paths))), dtype=wp.uint32, device=self._device) + parent_pos_wp = wp.array(world_pos_rows, dtype=wp.float32, device=self._device) + parent_ori_wp = wp.array(world_ori_rows, dtype=wp.float32, device=self._device) + parent_scale_wp = wp.array(world_scale_rows, dtype=wp.float32, device=self._device) + # Compose worldMatrix for parents (use a one-shot indexed array against + # ``world_sel_rw`` keyed on the unique parent paths). + parent_world_rw = wp.indexedfabricarray( + fa=wp.fabricarray(self._world_sel_rw, self._WORLD_MATRIX_NAME), + indices=self._compute_fabric_indices_for(self._world_sel_rw, unique_parent_paths), + ) + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=len(unique_parent_paths), + inputs=[ + parent_world_rw, + parent_pos_wp, + parent_ori_wp, + parent_scale_wp, + False, + False, + False, + parent_view_indices, + ], + device=self._device, + ) + wp.synchronize() - def _resolve_indices_wp(self, indices: wp.array | None) -> wp.array: - """Resolve view indices as a Warp uint32 array.""" - if indices is None or indices == slice(None): - if self._default_view_indices is None: - raise RuntimeError("Fabric indices are not initialized.") - return self._default_view_indices - if indices.dtype != wp.uint32: - return wp.array(indices.numpy().astype("uint32"), dtype=wp.uint32, device=self._device) - return indices + # The child worldMatrix above was composed with the child's *local* scale, + # which is wrong whenever a parent has a non-unit world scale. Mark the + # view dirty so the next world read fires ``_sync_world_from_local_if_dirty`` + # and recomputes ``child_world = parent_world * child_local`` — that + # multiply produces the correct world-space scale because the parent and + # local matrices now both carry the right scale (seeded above). + self._world_dirty = True + + def _compute_fabric_indices_for(self, selection, paths: list[str]) -> wp.array: + """Path-dict lookup helper used to build one-shot indexed arrays for a custom path set.""" + fabric_paths = selection.GetPaths() + path_to_idx = {str(p): i for i, p in enumerate(fabric_paths)} + indices: list[int] = [] + for path in paths: + idx = path_to_idx.get(path) + if idx is None: + raise RuntimeError(f"Path '{path}' not found in Fabric selection.") + indices.append(idx) + return wp.array(indices, dtype=wp.int32, device=self._device) diff --git a/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py b/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py index f0c18ccb98c7..11d7e1ab717c 100644 --- a/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py +++ b/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py @@ -23,7 +23,7 @@ import torch # noqa: E402 import warp as wp # noqa: E402 from frame_view_contract_utils import * # noqa: F401, F403, E402 -from frame_view_contract_utils import CHILD_OFFSET, ViewBundle, test_set_world_updates_local # noqa: E402 +from frame_view_contract_utils import CHILD_OFFSET, ViewBundle # noqa: E402 from isaaclab_physx.sim.views import FabricFrameView as FrameView # noqa: E402 from pxr import Gf, UsdGeom # noqa: E402 @@ -106,28 +106,11 @@ def factory(num_envs: int, device: str) -> ViewBundle: # ------------------------------------------------------------------ -# Override shared contract test with expected failure for Fabric. -# FabricFrameView.set_world_poses writes to Fabric worldMatrix only; the local -# pose (read via USD) does not reflect the change because there is no -# Fabric → USD writeback for local poses. This is tracked as Issue #5 -# (localMatrix: set_local_poses falls back to USD). +# Override: ensure the shared contract test runs without xfail now that +# get_local_poses computes local from Fabric world matrices. # ------------------------------------------------------------------ - - -@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) -@pytest.mark.xfail( - reason=( - "Issue #5: FabricFrameView.set_world_poses writes to Fabric worldMatrix only. " - "get_local_poses reads from stale USD because there is no Fabric→USD " - "writeback for local poses." - ), - strict=True, -) -def test_set_world_updates_local(device, view_factory): # noqa: F811 - """Override the shared test to mark it as expected failure.""" - from frame_view_contract_utils import test_set_world_updates_local as _impl # noqa: PLC0415 - - _impl(device, view_factory) +# (No override needed — the shared test_set_world_updates_local from +# frame_view_contract_utils is imported via wildcard and will run as-is.) # ------------------------------------------------------------------ @@ -188,48 +171,327 @@ def test_fabric_set_world_does_not_write_back_to_usd(device, view_factory): @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) -def test_fabric_rebuild_after_topology_change(device, view_factory, monkeypatch): - """Forcing the topology-changed branch on a write triggers - :meth:`_rebuild_fabric_arrays` and leaves the view in a state where - subsequent writes/reads still produce correct data. - - Real ``PrimSelection.PrepareForReuse`` reports topology change only when - Fabric reallocates internally, which is hard to provoke from a unit test. - Instead we monkeypatch ``_prepare_for_reuse`` on the instance to always - take the rebuild branch and verify the view remains usable. +def test_fabric_rebuild_after_topology_change(device, view_factory): + """A simulated topology change rebuilds the indexed fabric arrays and leaves + the view in a state where subsequent writes/reads still produce correct data. + + Real ``PrimSelection.PrepareForReuse`` reports topology change only when Fabric + reallocates internally, which is hard to provoke from a unit test. Instead we + invoke :meth:`FabricFrameView._compute_fabric_indices` and rebuild the indexed + arrays manually, mimicking what ``_get_*_array`` would do on a real topology + event, then verify a roundtrip still works. """ bundle = view_factory(2, device) view = bundle.view - # First write — initializes Fabric and binds _fabric_selection. + # First write — initializes Fabric. initial = wp.zeros((2, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=2, inputs=[initial, 1.0, 2.0, 3.0], device=device) view.set_world_poses(positions=initial) - rebuild_calls = [] - real_rebuild = view._rebuild_fabric_arrays - - def spy_rebuild(): - rebuild_calls.append(True) - real_rebuild() - - def force_topology_changed(): - if view._fabric_selection is not None: - view._fabric_selection.PrepareForReuse() - spy_rebuild() - - monkeypatch.setattr(view, "_prepare_for_reuse", force_topology_changed) + # Simulate topology change: recompute per-selection fabric indices and rebuild + # every indexed array, mirroring the lazy paths in the ``_get_*_array`` accessors. + view._rebuild_trans_ro_arrays() + view._world_rw_fabric_indices = view._compute_fabric_indices(view._world_sel_rw) + view._world_ifa_rw = view._build_indexed_array( + view._world_sel_rw, view._WORLD_MATRIX_NAME, view._world_rw_fabric_indices + ) + view._local_rw_fabric_indices = view._compute_fabric_indices(view._local_sel_rw) + view._local_ifa_rw = view._build_indexed_array( + view._local_sel_rw, view._LOCAL_MATRIX_NAME, view._local_rw_fabric_indices + ) - # Trigger another write — goes through the forced topology-change branch. + # Trigger another write through the rebuilt arrays. new = wp.zeros((2, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=2, inputs=[new, 4.0, 5.0, 6.0], device=device) view.set_world_poses(positions=new) - assert rebuild_calls, "Forced topology-change branch did not invoke _rebuild_fabric_arrays" - - # Read back — proves the rebuilt _view_to_fabric and _fabric_world_matrices - # are still consistent. ret_pos, _ = view.get_world_poses() pos_torch = wp.to_torch(ret_pos) expected = torch.tensor([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]], device=device) - assert torch.allclose(pos_torch, expected, atol=1e-7), f"Read after rebuild failed on {device}: {pos_torch}" + # 1e-5 ≈ 20 ULP at magnitudes ~4-6; absorbs float32 SRT compose/decompose drift. + assert torch.allclose(pos_torch, expected, atol=1e-5), f"Read after rebuild failed on {device}: {pos_torch}" + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_prepare_for_reuse_detects_topology_change(device, view_factory): + """Each persistent ``PrimSelection`` exposes ``PrepareForReuse`` and returns a + bool. When the underlying Fabric topology is unchanged it returns False. + """ + bundle = view_factory(1, device) + view = bundle.view + view.get_world_poses() # trigger Fabric init + + assert view._trans_sel_ro is not None, "trans_sel_ro selection not initialized" + for selection in (view._trans_sel_ro, view._world_sel_rw, view._local_sel_rw): + result = selection.PrepareForReuse() + assert isinstance(result, bool), f"PrepareForReuse should return bool, got {type(result)}" + assert not result, "PrepareForReuse should return False when no topology change" + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_set_local_via_fabric_path(device, view_factory): + """Exercise the Fabric-native set_local_poses path. + + Ensures set_local_poses computes child_world = parent_world * local + entirely within Fabric (not falling back to USD) by first triggering + the Fabric sync via get_world_poses. + """ + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + + # Trigger lazy `_initialize_fabric()` so subsequent calls take the Fabric path. + view.get_world_poses() + + # Now set_local_poses should take the Fabric path + new_local_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_pos, 1.0, 2.0, 3.0], device=device) + ori = torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device) + new_local_ori = wp.from_torch(ori) + + view.set_local_poses(translations=new_local_pos, orientations=new_local_ori) + + # Verify: world = parent(0,0,1) + local(1,2,3) = (1,2,4) + world_pos, _ = view.get_world_poses() + expected = torch.tensor([[1.0, 2.0, 4.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(world_pos.torch, expected, atol=1e-4, rtol=0) + + # Verify get_local_poses returns the local offset + local_pos, _ = view.get_local_poses() + expected_local = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(local_pos.torch, expected_local, atol=1e-4, rtol=0) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_get_scales_fabric_path(device, view_factory): + """Exercise the Fabric-native get_scales path.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + + # Trigger lazy `_initialize_fabric()` so the get_scales call below uses Fabric. + view.get_world_poses() + + scales = view.get_scales() + scales_t = wp.to_torch(scales) + # Default scale should be (1, 1, 1) + expected = torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(scales_t, expected, atol=1e-4, rtol=0) + + +# ------------------------------------------------------------------ +# Transpose-convention verification: world ↔ local kernels rely on the +# identity ``(A·B)ᵀ = Bᵀ·Aᵀ`` to drop explicit transposes when operating +# on Fabric's column-transposed matrix storage. The translation-only +# parents used by the standard fixture cannot distinguish the right +# convention from the wrong one — the rotation block is identity and +# equals its own transpose. These tests use a parent rotated 90° around +# Z so that an incorrect storage convention would produce a clearly +# wrong child pose. +# ------------------------------------------------------------------ + + +# Parent at (0, 0, 1) rotated +90° around Z (so the parent X axis points +# along world +Y). Quaternion components in (x, y, z, w) order. +_ROTATED_PARENT_POS = (0.0, 0.0, 1.0) +_ROTATED_PARENT_QUAT_XYZW = (0.0, 0.0, 0.70710678, 0.70710678) + + +def _build_rotated_parent_view(device: str) -> "FrameView": + """Build a 1-env FabricFrameView whose parent is rotated 90° around Z.""" + stage = sim_utils.get_current_stage() + sim_utils.create_prim( + "/World/Parent_0", + "Xform", + translation=_ROTATED_PARENT_POS, + orientation=_ROTATED_PARENT_QUAT_XYZW, + stage=stage, + ) + sim_utils.create_prim("/World/Parent_0/Child", "Camera", translation=(0.0, 0.0, 0.0), stage=stage) + sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view = FrameView("/World/Parent_.*/Child", device=device) + view.get_world_poses() # force Fabric init and USD→Fabric seed + return view + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_set_local_then_get_world_with_rotated_parent(device): + """Verify ``update_indexed_world_matrix_from_local`` under non-identity parent rotation. + + With parent rotated +90° around Z, a child local translation of (1, 0, 0) + must produce world translation (0, 1, 1) — parent_pos + R · local. If the + transpose convention in the kernel were wrong, the rotation would flip + direction and the world position would land at (0, -1, 1) instead. + """ + _skip_if_unavailable(device) + view = _build_rotated_parent_view(device) + + new_local = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local, 1.0, 0.0, 0.0], device=device) + identity_quat = wp.from_torch(torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device)) + view.set_local_poses(translations=new_local, orientations=identity_quat) + + world_pos, _ = view.get_world_poses() + expected = torch.tensor([[0.0, 1.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(world_pos.torch, expected, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_set_world_then_get_local_with_rotated_parent(device): + """Verify ``update_indexed_local_matrix_from_world`` under non-identity parent rotation. + + With parent rotated +90° around Z and at (0, 0, 1), writing child world + translation (5, 0, 2) must yield child local translation Rᵀ · (5, 0, 1) = + (0, -5, 1). A wrong transpose convention would invert the rotation in the + wrong direction and produce (0, 5, 1) instead. + """ + _skip_if_unavailable(device) + view = _build_rotated_parent_view(device) + + new_world = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_world, 5.0, 0.0, 2.0], device=device) + view.set_world_poses(positions=new_world) + + local_pos, _ = view.get_local_poses() + expected = torch.tensor([[0.0, -5.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(local_pos.torch, expected, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_initial_seed_with_scaled_parent(device): + """Verify the initial USD→Fabric seed handles non-unit scales correctly. + + Sets up a parent with world scale (2, 1, 1) and a child with local scale + (3, 1, 1) at local translation (1, 0, 0). Expected world-space values for + the child: + + * world scale = parent_scale * child_local_scale = (6, 1, 1) + * world position = parent_pos + parent_scale * child_local_pos + = (0, 0, 1) + (2 * 1, 0, 0) = (2, 0, 1) + + If the parent's worldMatrix is seeded with a hardcoded unit scale, + ``get_scales`` returns (3, 1, 1) instead of (6, 1, 1) and ``get_world_poses`` + returns (1, 0, 1) instead of (2, 0, 1). If the child's localMatrix is + seeded without scale, after ``_sync_world_from_local_if_dirty`` the world + scale collapses to (2, 1, 1). This test catches both regressions. + """ + _skip_if_unavailable(device) + stage = sim_utils.get_current_stage() + sim_utils.create_prim("/World/Parent_0", "Xform", translation=(0.0, 0.0, 1.0), scale=(2.0, 1.0, 1.0), stage=stage) + sim_utils.create_prim( + "/World/Parent_0/Child", + "Camera", + translation=(1.0, 0.0, 0.0), + scale=(3.0, 1.0, 1.0), + stage=stage, + ) + sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view = FrameView("/World/Parent_.*/Child", device=device) + + world_pos, _ = view.get_world_poses() + torch.testing.assert_close( + world_pos.torch, + torch.tensor([[2.0, 0.0, 1.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + scales = wp.to_torch(view.get_scales()) + torch.testing.assert_close( + scales, + torch.tensor([[6.0, 1.0, 1.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + +# ------------------------------------------------------------------ +# Multi-view per stage: per-view dirty-flag isolation +# ------------------------------------------------------------------ + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_multi_view_per_view_dirty_isolation(device): + """Two ``FabricFrameView`` instances on the same stage must not clear each other's + pending local→world sync. + + Background: an earlier implementation stored the world-dirty flag at the class + level keyed by ``stage_id``. With two views on the same stage, view B reading + worlds would clear the flag set by view A's ``set_local_poses``, leaving A's + world matrices silently stale because A's per-view sync kernel never fired. + + This test sets up two views over disjoint child prims (under different parent + sub-trees of the same stage), interleaves their writes and reads, and verifies: + + * view A's ``set_local_poses`` only dirties view A + * view B's ``get_world_poses`` does not clear view A's flag + * after both views' world reads, each one's worlds reflect its own latest local + * neither view's reads/writes corrupt the other view's poses + """ + _skip_if_unavailable(device) + stage = sim_utils.get_current_stage() + + # Two disjoint sub-trees under the same stage. Use different parent names so + # the regex patterns for the two views don't accidentally overlap. + sim_utils.create_prim("/World/EnvA_0", "Xform", translation=(0.0, 0.0, 1.0), stage=stage) + sim_utils.create_prim("/World/EnvA_0/ChildA", "Camera", translation=(0.1, 0.0, 0.0), stage=stage) + sim_utils.create_prim("/World/EnvB_0", "Xform", translation=(0.0, 0.0, 2.0), stage=stage) + sim_utils.create_prim("/World/EnvB_0/ChildB", "Camera", translation=(0.2, 0.0, 0.0), stage=stage) + + sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view_a = FrameView("/World/EnvA_.*/ChildA", device=device) + view_b = FrameView("/World/EnvB_.*/ChildB", device=device) + + # Initial reads — triggers Fabric init + the seed-time ``_world_dirty = True`` + # path on both views, then clears it. + expected_a0 = torch.tensor([[0.1, 0.0, 1.0]], dtype=torch.float32, device=device) + expected_b0 = torch.tensor([[0.2, 0.0, 2.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(view_a.get_world_poses()[0].torch, expected_a0, atol=1e-5, rtol=0) + torch.testing.assert_close(view_b.get_world_poses()[0].torch, expected_b0, atol=1e-5, rtol=0) + assert view_a._world_dirty is False + assert view_b._world_dirty is False + # Both views must reuse the same cached IFabricHierarchy (one stage = one handle). + assert view_a._fabric_hierarchy is view_b._fabric_hierarchy + from isaaclab_physx.sim.fabric_stage_cache import FabricStageCache + + sim_context = sim_utils.SimulationContext.instance() + cache = sim_context.get_service(FabricStageCache) + assert cache is not None + assert len(cache._hierarchy_cache) == 1 + + # Write a new local pose on view A only. + new_local_a = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_a, 1.0, 0.0, 0.0], device=device) + identity_quat = wp.from_torch(torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device)) + view_a.set_local_poses(translations=new_local_a, orientations=identity_quat) + + # Only view A should be dirty. Critical: a per-stage flag would have dirtied + # both views (or neither) at this point. + assert view_a._world_dirty is True, "set_local_poses should mark its own view dirty" + assert view_b._world_dirty is False, "set_local_poses on view A must not dirty view B" + + # Read worlds from view B FIRST. With a per-stage flag, B's + # ``_sync_world_from_local_if_dirty`` would fire and clear the flag, leaving A + # stale. With the per-view flag, B's read is a no-op sync-wise. + torch.testing.assert_close(view_b.get_world_poses()[0].torch, expected_b0, atol=1e-5, rtol=0) + assert view_b._world_dirty is False + assert view_a._world_dirty is True, "view B's world read must not clear view A's dirty flag" + + # Now read view A's worlds — sync fires, world reflects the new local. + expected_a1 = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(view_a.get_world_poses()[0].torch, expected_a1, atol=1e-5, rtol=0) + assert view_a._world_dirty is False + + # Symmetric pass: write on B, ensure A is undisturbed. + new_local_b = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_b, 3.0, 0.0, 0.0], device=device) + view_b.set_local_poses(translations=new_local_b, orientations=identity_quat) + assert view_a._world_dirty is False + assert view_b._world_dirty is True + + # A's worlds must still read back the post-set-local value from above; no + # cross-view stomp on the world matrix. + torch.testing.assert_close(view_a.get_world_poses()[0].torch, expected_a1, atol=1e-5, rtol=0) + expected_b1 = torch.tensor([[3.0, 0.0, 2.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(view_b.get_world_poses()[0].torch, expected_b1, atol=1e-5, rtol=0) + assert view_a._world_dirty is False + assert view_b._world_dirty is False diff --git a/uv.lock b/uv.lock new file mode 100644 index 000000000000..04b814a22708 --- /dev/null +++ b/uv.lock @@ -0,0 +1,9 @@ +version = 1 +revision = 3 +requires-python = ">=3.12" + +[options] +prerelease-mode = "allow" + +[manifest] +overrides = [{ name = "numpy", specifier = ">=2" }]