From c44d4f9037028874912c5e1cb11b2bb09ef127ac Mon Sep 17 00:00:00 2001 From: falletta Date: Wed, 18 Feb 2026 07:12:18 -0800 Subject: [PATCH 1/3] fix atom_idx remap --- tests/test_constraints.py | 27 +++++++++++++++++++++++++++ torch_sim/state.py | 8 ++++++++ 2 files changed, 35 insertions(+) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index a98833e48..49b5532c9 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -909,6 +909,33 @@ def test_atom_indexed_constraint_update_and_select() -> None: assert sub_constraint is None +def test_constraint_atom_idx_remapped_on_reordered_slice( + mixed_double_sim_state: ts.SimState, +) -> None: + """Slicing a batched state with reversed system order must remap constraint + atom_idx so the constraint targets the same physical atoms in the new ordering. + """ + state = mixed_double_sim_state + n_sys0 = (state.system_idx == 0).sum().item() + fix_global_idx = n_sys0 + 1 # second atom of system 1 + state.constraints = [FixAtoms(atom_idx=torch.tensor([fix_global_idx]))] + + # Reverse system order: system 1's atoms come first in the new ordering, + # so the fixed atom (second of old system 1) lands at new index 1. + sliced = state[[1, 0]] + + constraint = sliced.constraints[0] + assert isinstance(constraint, FixAtoms) + + assert constraint.atom_idx.item() == 1 + + forces = torch.ones(sliced.n_atoms, 3, dtype=DTYPE) + constraint.adjust_forces(sliced, forces) + assert torch.all(forces[1] == 0.0) + assert torch.all(forces[:1] == 1.0) + assert torch.all(forces[2:] == 1.0) + + def test_merge_constraints(mixed_double_sim_state: ts.SimState) -> None: """Test merge_constraints combines constraints from multiple systems.""" # Split the double system state diff --git a/torch_sim/state.py b/torch_sim/state.py index 1a1eff180..8fe5479bd 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -796,6 +796,14 @@ def _filter_attrs_by_index( if (c := con.select_constraint(atom_mask, system_mask)) ] + # Remap constraint atom_idx to account for reordering by atom_indices + atom_remap = torch.empty(state.n_atoms, dtype=torch.long, device=state.device) + atom_remap[atom_indices] = torch.arange(len(atom_indices), device=state.device) + dense_to_reordered = atom_remap[torch.where(atom_mask)[0]] + for c in filtered_attrs["_constraints"]: + if hasattr(c, "atom_idx") and isinstance(c.atom_idx, torch.Tensor): + c.atom_idx = dense_to_reordered[c.atom_idx] + # Build inverse map for system_idx remapping (old index -> new position) if len(system_indices) == 0: inv = torch.empty(0, device=state.device, dtype=torch.long) From 67c6efa3f536d5fed92b8236d8a9cd00da7e4fd7 Mon Sep 17 00:00:00 2001 From: falletta Date: Wed, 18 Feb 2026 09:01:17 -0800 Subject: [PATCH 2/3] minor revisions --- tests/test_constraints.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 49b5532c9..b7072c0db 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -916,24 +916,25 @@ def test_constraint_atom_idx_remapped_on_reordered_slice( atom_idx so the constraint targets the same physical atoms in the new ordering. """ state = mixed_double_sim_state - n_sys0 = (state.system_idx == 0).sum().item() + n_sys0 = state.n_atoms_per_system[0].item() fix_global_idx = n_sys0 + 1 # second atom of system 1 state.constraints = [FixAtoms(atom_idx=torch.tensor([fix_global_idx]))] # Reverse system order: system 1's atoms come first in the new ordering, # so the fixed atom (second of old system 1) lands at new index 1. sliced = state[[1, 0]] + expected_atom_idx = 1 constraint = sliced.constraints[0] assert isinstance(constraint, FixAtoms) - assert constraint.atom_idx.item() == 1 + assert constraint.atom_idx.item() == expected_atom_idx forces = torch.ones(sliced.n_atoms, 3, dtype=DTYPE) constraint.adjust_forces(sliced, forces) - assert torch.all(forces[1] == 0.0) - assert torch.all(forces[:1] == 1.0) - assert torch.all(forces[2:] == 1.0) + assert torch.all(forces[expected_atom_idx] == 0.0) + assert torch.all(forces[:expected_atom_idx] == 1.0) + assert torch.all(forces[expected_atom_idx + 1 :] == 1.0) def test_merge_constraints(mixed_double_sim_state: ts.SimState) -> None: From 434a364a324a8fb5006003ca4a441594242fdb76 Mon Sep 17 00:00:00 2001 From: falletta Date: Wed, 18 Feb 2026 12:56:57 -0800 Subject: [PATCH 3/3] renamed var --- torch_sim/state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 8fe5479bd..bd0dd1f2a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -799,10 +799,10 @@ def _filter_attrs_by_index( # Remap constraint atom_idx to account for reordering by atom_indices atom_remap = torch.empty(state.n_atoms, dtype=torch.long, device=state.device) atom_remap[atom_indices] = torch.arange(len(atom_indices), device=state.device) - dense_to_reordered = atom_remap[torch.where(atom_mask)[0]] + new_atom_idx = atom_remap[torch.where(atom_mask)[0]] for c in filtered_attrs["_constraints"]: if hasattr(c, "atom_idx") and isinstance(c.atom_idx, torch.Tensor): - c.atom_idx = dense_to_reordered[c.atom_idx] + c.atom_idx = new_atom_idx[c.atom_idx] # Build inverse map for system_idx remapping (old index -> new position) if len(system_indices) == 0: