diff --git a/tests/test_constraints.py b/tests/test_constraints.py index a98833e48..b7072c0db 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -909,6 +909,34 @@ def test_atom_indexed_constraint_update_and_select() -> None: assert sub_constraint is None +def test_constraint_atom_idx_remapped_on_reordered_slice( + mixed_double_sim_state: ts.SimState, +) -> None: + """Slicing a batched state with reversed system order must remap constraint + atom_idx so the constraint targets the same physical atoms in the new ordering. + """ + state = mixed_double_sim_state + n_sys0 = state.n_atoms_per_system[0].item() + fix_global_idx = n_sys0 + 1 # second atom of system 1 + state.constraints = [FixAtoms(atom_idx=torch.tensor([fix_global_idx]))] + + # Reverse system order: system 1's atoms come first in the new ordering, + # so the fixed atom (second of old system 1) lands at new index 1. + sliced = state[[1, 0]] + expected_atom_idx = 1 + + constraint = sliced.constraints[0] + assert isinstance(constraint, FixAtoms) + + assert constraint.atom_idx.item() == expected_atom_idx + + forces = torch.ones(sliced.n_atoms, 3, dtype=DTYPE) + constraint.adjust_forces(sliced, forces) + assert torch.all(forces[expected_atom_idx] == 0.0) + assert torch.all(forces[:expected_atom_idx] == 1.0) + assert torch.all(forces[expected_atom_idx + 1 :] == 1.0) + + def test_merge_constraints(mixed_double_sim_state: ts.SimState) -> None: """Test merge_constraints combines constraints from multiple systems.""" # Split the double system state diff --git a/torch_sim/state.py b/torch_sim/state.py index 1a1eff180..bd0dd1f2a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -796,6 +796,14 @@ def _filter_attrs_by_index( if (c := con.select_constraint(atom_mask, system_mask)) ] + # Remap constraint atom_idx to account for reordering by atom_indices + atom_remap = torch.empty(state.n_atoms, dtype=torch.long, device=state.device) + atom_remap[atom_indices] = torch.arange(len(atom_indices), device=state.device) + new_atom_idx = atom_remap[torch.where(atom_mask)[0]] + for c in filtered_attrs["_constraints"]: + if hasattr(c, "atom_idx") and isinstance(c.atom_idx, torch.Tensor): + c.atom_idx = new_atom_idx[c.atom_idx] + # Build inverse map for system_idx remapping (old index -> new position) if len(system_indices) == 0: inv = torch.empty(0, device=state.device, dtype=torch.long)