From 28bc900ee81292a2d8a96349ca2ee4db99f799a7 Mon Sep 17 00:00:00 2001 From: Pascal Salzbrenner Date: Thu, 23 Apr 2026 16:06:36 +0000 Subject: [PATCH 1/7] Fix bugs with constrained optimisation (various missing updates / inconsistencies with atom positions when cell is adjusted, e.g. due to symmetry) --- torch_sim/optimizers/fire.py | 27 ++++++++++++++++++++ torch_sim/optimizers/lbfgs.py | 47 +++++++++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index e45c7ec8d..66d8b83fc 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -449,6 +449,33 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 # (needed for correct displacement calculation in position constraints) state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) + # Resync cell_positions to match the (possibly adjusted) cell so + # the next step builds on the correct base instead of the + # pre-adjustment value. Without this, any constraint that + # modifies the cell (e.g. FixSymmetry) causes a zigzag where the + # optimizer repeatedly proposes from a stale cell_positions. 
+ adjusted_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) + if is_frechet: + cell_factor_reshaped = state.cell_factor.view( + state.n_systems, 1, 1 + ) + state.cell_positions = ( + tsm.matrix_log_33( + adjusted_deform_grad, sim_dtype=state.dtype + ) + * cell_factor_reshaped + ) + else: + cell_factor_expanded = state.cell_factor.expand( + state.n_systems, 3, 1 + ) + state.cell_positions = ( + adjusted_deform_grad.reshape(state.n_systems, 3, 3) + * cell_factor_expanded + ) + # Transform fractional positions to Cartesian using NEW deformation gradient new_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index cb88c6577..482fde6cc 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -464,9 +464,12 @@ def lbfgs_step( # noqa: PLR0915, C901 # Save previous state for history update # For cell state: store fractional positions and scaled forces (ASE convention) if isinstance(state, CellLBFGSState): - state.prev_positions = frac_positions.clone() # [N, 3] (fractional) - state.prev_forces = forces_scaled.clone() # [N, 3] (scaled) + # Store cell prev state BEFORE the cell step so that the history + # update computes s_new_cell = resynced_current - start_of_step, + # correctly capturing the actual cell displacement. state.prev_cell_positions = state.cell_positions.clone() # [S, 3, 3] + # prev_cell_forces comes from compute_cell_forces which already uses + # the adjusted cell, so it's in the correct frame. state.prev_cell_forces = state.cell_forces.clone() # [S, 3, 3] # Apply cell step @@ -499,12 +502,46 @@ def lbfgs_step( # noqa: PLR0915, C901 ) # [S, 3, 3] state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) + # Resync cell_positions to match the (possibly adjusted) cell so + # the next step builds on the correct base instead of the + # pre-adjustment value. 
Without this, any constraint that + # modifies the cell (e.g. FixSymmetry) causes a zigzag where the + # optimizer repeatedly proposes from a stale cell_positions. + adjusted_deform_grad = deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + if is_frechet: + cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1) + state.cell_positions = ( + ts.math.matrix_log_33( + adjusted_deform_grad, sim_dtype=state.dtype + ) + * cell_factor_reshaped + ) + else: + cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) + state.cell_positions = ( + adjusted_deform_grad.reshape(n_systems, 3, 3) + * cell_factor_expanded + ) + + # Store prev_positions/prev_forces in the ADJUSTED cell's frame so + # they are consistent with the next step's start-of-step deformation + # gradient. Without this, the LBFGS history vectors (s, y) mix two + # different coordinate frames, corrupting the Hessian estimate. + state.prev_positions = torch.linalg.solve( + adjusted_deform_grad[state.system_idx], + state.positions.unsqueeze(-1), + ).squeeze(-1) # [N, 3] (fractional in adjusted frame) + state.prev_forces = torch.bmm( + state.forces.unsqueeze(1), + adjusted_deform_grad[state.system_idx], + ).squeeze(1) # [N, 3] (scaled in adjusted frame) + # Apply position step in fractional space, then convert to Cartesian new_frac = frac_positions + step_positions # [N, 3] - new_deform_grad = deform_grad( - state.reference_cell.mT, state.row_vector_cell - ) # [S, 3, 3] + new_deform_grad = adjusted_deform_grad # already computed above # new_positions = new_frac @ deform_grad^T new_positions = torch.bmm( new_frac.unsqueeze(1), # [N, 1, 3] From 02819b2e928c9e7ebd98c92905bf6275e0deab86 Mon Sep 17 00:00:00 2001 From: Pascal Salzbrenner Date: Fri, 24 Apr 2026 09:38:11 +0000 Subject: [PATCH 2/7] Remove max_strain clamping which was hindering / preventing full relaxation with FixSymmetry --- torch_sim/constraints.py | 72 +++------------------------------------- 1 file changed, 5 
insertions(+), 67 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index c6073f013..66231b26f 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -774,10 +774,8 @@ class FixSymmetry(SystemConstraint): rotations: list[torch.Tensor] symm_maps: list[torch.Tensor] - reference_cells: list[torch.Tensor] | None do_adjust_positions: bool do_adjust_cell: bool - max_cumulative_strain: float def __init__( self, @@ -787,8 +785,6 @@ def __init__( *, adjust_positions: bool = True, adjust_cell: bool = True, - reference_cells: list[torch.Tensor] | None = None, - max_cumulative_strain: float = 0.5, ) -> None: """Initialize FixSymmetry constraint. @@ -798,11 +794,6 @@ def __init__( system_idx: System indices (defaults to 0..n_systems-1). adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. - reference_cells: Initial refined cells (row vectors) per system for - cumulative strain tracking. If None, cumulative check is skipped. - max_cumulative_strain: Maximum allowed cumulative strain from the - reference cell. If exceeded, the cell update is clamped to - keep the structure within this strain envelope. 
""" n_systems = len(rotations) if len(symm_maps) != n_systems: @@ -817,19 +808,12 @@ def __init__( raise ValueError( f"system_idx length ({len(system_idx)}) != n_systems ({n_systems})" ) - if reference_cells is not None and len(reference_cells) != n_systems: - raise ValueError( - f"reference_cells length ({len(reference_cells)}) " - f"!= n_systems ({n_systems})" - ) super().__init__(system_idx=system_idx) self.rotations = rotations self.symm_maps = symm_maps - self.reference_cells = reference_cells self.do_adjust_positions = adjust_positions self.do_adjust_cell = adjust_cell - self.max_cumulative_strain = max_cumulative_strain @classmethod def from_state( @@ -866,7 +850,7 @@ def from_state( from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry - rotations, symm_maps, reference_cells = [], [], [] + rotations, symm_maps = [], [] cumsum = _cumsum_with_zero(state.n_atoms_per_system) for sys_idx in range(state.n_systems): @@ -896,8 +880,6 @@ def from_state( rotations.append(rots) symm_maps.append(smap) - # Store the refined cell as the reference for cumulative strain tracking - reference_cells.append(state.row_vector_cell[sys_idx].clone()) return cls( rotations, @@ -905,7 +887,6 @@ def from_state( system_idx=torch.arange(state.n_systems, device=state.device), adjust_positions=adjust_positions, adjust_cell=adjust_cell, - reference_cells=reference_cells, ) def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: @@ -938,10 +919,8 @@ def adjust_cell(self, state: SimState, cell: torch.Tensor) -> None: Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``. - Also checks cumulative strain from the initial reference cell. If the - total deformation exceeds ``max_cumulative_strain``, the update is - clamped to prevent phase transitions that would break the symmetry - constraint (e.g. hexagonal → tetragonal cell collapse). 
+ Per-step deformation is clamped at 0.25 to avoid ill-conditioned + symmetrization, matching the ASE FixSymmetry behaviour. Args: state: Current simulation state. @@ -961,8 +940,7 @@ def adjust_cell(self, state: SimState, cell: torch.Tensor) -> None: new_row = cell[si].mT # column → row convention # Per-step deformation: clamp large steps to avoid ill-conditioned - # symmetrization while still making progress. The cumulative strain - # guard below is the real safety net against phase transitions. + # symmetrization while still making progress. deform_delta = torch.linalg.solve(cur_cell, new_row) - identity max_delta = torch.abs(deform_delta).max().item() if not math.isfinite(max_delta): @@ -978,17 +956,6 @@ def adjust_cell(self, state: SimState, cell: torch.Tensor) -> None: sym_delta = symmetrize_rank2(cur_cell, deform_delta, rots) proposed_cell = cur_cell @ (sym_delta + identity) - # Cumulative strain check against reference cell - if self.reference_cells is not None: - ref_cell = self.reference_cells[ci].to( - device=state.device, dtype=state.dtype - ) - cumulative_strain = torch.linalg.solve(ref_cell, proposed_cell) - identity - max_cumulative = torch.abs(cumulative_strain).max().item() - if max_cumulative > self.max_cumulative_strain: - scale = self.max_cumulative_strain / max_cumulative - proposed_cell = ref_cell @ (cumulative_strain * scale + identity) - cell[si] = proposed_cell.mT # back to column convention def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: @@ -1018,8 +985,6 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 self.system_idx + system_offset, adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, - reference_cells=list(self.reference_cells) if self.reference_cells else None, - max_cumulative_strain=self.max_cumulative_strain, ) def to( @@ -1034,12 +999,6 @@ def to( self.system_idx.to(device=device), adjust_positions=self.do_adjust_positions, 
adjust_cell=self.do_adjust_cell, - reference_cells=( - [c.to(device=device, dtype=dtype) for c in self.reference_cells] - if self.reference_cells is not None - else None - ), - max_cumulative_strain=self.max_cumulative_strain, ) @classmethod @@ -1051,32 +1010,21 @@ def merge(cls, constraints: list[Constraint]) -> Self: if any( c.do_adjust_positions != fix_sym_constraints[0].do_adjust_positions or c.do_adjust_cell != fix_sym_constraints[0].do_adjust_cell - or c.max_cumulative_strain != fix_sym_constraints[0].max_cumulative_strain for c in fix_sym_constraints[1:] ): raise ValueError( "Cannot merge FixSymmetry constraints with different " - "adjust_positions/adjust_cell/max_cumulative_strain settings" + "adjust_positions/adjust_cell settings" ) rotations = [r for c in fix_sym_constraints for r in c.rotations] symm_maps = [s for c in fix_sym_constraints for s in c.symm_maps] system_idx = torch.cat([c.system_idx for c in fix_sym_constraints]) - # Merge reference cells if all constraints have them - ref_cells = None - if all(c.reference_cells is not None for c in fix_sym_constraints): - ref_cells = [] - for c in fix_sym_constraints: - refs = c.reference_cells - if refs is not None: - ref_cells.extend(refs) return cls( rotations, symm_maps, system_idx=system_idx, adjust_positions=fix_sym_constraints[0].do_adjust_positions, adjust_cell=fix_sym_constraints[0].do_adjust_cell, - reference_cells=ref_cells, - max_cumulative_strain=fix_sym_constraints[0].max_cumulative_strain, ) def select_constraint( @@ -1090,19 +1038,12 @@ def select_constraint( if not mask.any(): return None local_idx = mask.nonzero(as_tuple=False).flatten().tolist() - ref_cells = ( - [self.reference_cells[idx] for idx in local_idx] - if self.reference_cells - else None - ) return type(self)( [self.rotations[idx] for idx in local_idx], [self.symm_maps[idx] for idx in local_idx], _mask_constraint_indices(self.system_idx[mask], system_mask), adjust_positions=self.do_adjust_positions, 
adjust_cell=self.do_adjust_cell, - reference_cells=ref_cells, - max_cumulative_strain=self.max_cumulative_strain, ) def select_sub_constraint( @@ -1114,15 +1055,12 @@ def select_sub_constraint( if sys_idx not in self.system_idx: return None local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() - ref_cells = [self.reference_cells[local]] if self.reference_cells else None return type(self)( [self.rotations[local]], [self.symm_maps[local]], torch.tensor([0], device=self.system_idx.device), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, - reference_cells=ref_cells, - max_cumulative_strain=self.max_cumulative_strain, ) def __repr__(self) -> str: From 74ef353bb9ac81a517e40db3ae865092624de886 Mon Sep 17 00:00:00 2001 From: Pascal Salzbrenner Date: Fri, 24 Apr 2026 13:47:28 +0000 Subject: [PATCH 3/7] Add tests for found bugs / new additions --- tests/test_fix_symmetry.py | 184 ++++++++++++++++++++++++++++--------- tests/test_optimizers.py | 53 +++++++++++ 2 files changed, 194 insertions(+), 43 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 87ef3ee5a..bf1b6222c 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -15,6 +15,9 @@ from torch_sim.constraints import FixCom, FixSymmetry from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.optimizers.cell_filters import CellLBFGSState, deform_grad +from torch_sim.optimizers.fire import fire_init, fire_step +from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step from torch_sim.symmetrize import get_symmetry_datasets @@ -285,9 +288,8 @@ def test_large_deformation_clamped(self) -> None: assert not torch.allclose(new_cell, orig_cell * 1.5, atol=1e-6) # Per-step clamp limits single-step strain to 0.25 identity = torch.eye(3, dtype=DTYPE) - assert constraint.reference_cells is not None - ref_cell = constraint.reference_cells[0] - strain = 
torch.linalg.solve(ref_cell, new_cell[0].mT) - identity + cur_cell = state.row_vector_cell[0] + strain = torch.linalg.solve(cur_cell, new_cell[0].mT) - identity assert torch.abs(strain).max().item() <= 0.25 + 1e-6 def test_nan_deformation_raises(self) -> None: @@ -300,15 +302,11 @@ def test_nan_deformation_raises(self) -> None: constraint.adjust_cell(state, new_cell) def test_init_mismatched_lengths_raises(self) -> None: - """Mismatched rotations/symm_maps/reference_cells lengths raise ValueError.""" + """Mismatched rotations/symm_maps lengths raise ValueError.""" rots = [torch.eye(3).unsqueeze(0)] smaps = [torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 2, dtype=torch.long)] with pytest.raises(ValueError, match="length mismatch"): FixSymmetry(rots, smaps) - # reference_cells length must match n_systems - smaps_ok = [torch.zeros(1, 1, dtype=torch.long)] - with pytest.raises(ValueError, match="reference_cells length"): - FixSymmetry(rots, smaps_ok, reference_cells=[torch.eye(3), torch.eye(3)]) @pytest.mark.parametrize("method", ["adjust_positions", "adjust_cell"]) def test_adjust_skipped_when_disabled(self, method: str) -> None: @@ -665,44 +663,144 @@ def test_noisy_model_preserves_symmetry_with_constraint( assert result["initial_spacegroups"][0] == 229 assert result["final_spacegroups"][0] == 229 - def test_cumulative_strain_clamp_direct(self) -> None: - """adjust_cell clamps deformation when cumulative strain exceeds limit. - Directly tests the clamping mechanism by repeatedly applying small - cell deformations that individually pass the per-step check (< 0.25) - but cumulatively exceed max_cumulative_strain. Verifies: - 1. The cell doesn't drift beyond the strain envelope - 2. Symmetry is preserved after many small steps - """ - state = ts.io.atoms_to_state(make_structure("fcc", repeats=1), DEVICE, DTYPE) +class TestFixSymmetryCellPositionsResync: + """Tests that cell_positions stays consistent with the actual cell after + optimizer steps with FixSymmetry. 
These would have caught the + cell_positions desync bug (Fix 1) and the batching discrepancy. + """ + + @pytest.mark.parametrize( + "optimizer", + [ + pytest.param((fire_init, fire_step), id="fire"), + pytest.param((lbfgs_init, lbfgs_step), id="lbfgs"), + ], + ) + def test_cell_positions_consistent_after_step( + self, + model: LennardJonesModel, + optimizer: tuple, + ) -> None: + """cell_positions matches actual cell after one step with FixSymmetry.""" + state = ts.io.atoms_to_state(make_structure("hcp"), DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - constraint.max_cumulative_strain = 0.15 - assert constraint.reference_cells is not None - ref_cell = constraint.reference_cells[0].clone() + state.constraints = [constraint] + state.cell = state.cell * 0.95 + state.positions = state.positions * 0.95 - # Apply 20 small deformations (each ~5% along one axis) - # Total would be ~100% without clamping, well over the 0.15 limit - identity = torch.eye(3, dtype=DTYPE) - for _ in range(20): - # Stretch c-axis by 5% (cubic symmetrization isotropizes this) - stretch = identity.clone() - stretch[2, 2] = 1.05 - new_cell = (state.row_vector_cell[0] @ stretch).mT.unsqueeze(0) - constraint.adjust_cell(state, new_cell) - state.cell = new_cell - - # Cumulative strain must be clamped to the limit - final_cell = state.row_vector_cell[0] - cumulative = torch.linalg.solve(ref_cell, final_cell) - identity - max_strain = torch.abs(cumulative).max().item() - assert max_strain <= constraint.max_cumulative_strain + 1e-6, ( - f"Strain {max_strain:.4f} exceeded {constraint.max_cumulative_strain}" + init_fn, step_fn = optimizer + opt_state = init_fn(state, model, cell_filter=ts.CellFilter.frechet) + + step_kwargs = {} + if init_fn is fire_init: + step_kwargs["fire_flavor"] = "ase_fire" + opt_state = step_fn(state=opt_state, model=model, **step_kwargs) + + # Recompute expected cell_positions from the actual cell + cur_dg = deform_grad(opt_state.reference_cell.mT, 
opt_state.row_vector_cell) + expected_cp = ( + ts.math.matrix_log_33(cur_dg, sim_dtype=opt_state.dtype) + * opt_state.cell_factor.view(opt_state.n_systems, 1, 1) + ) + assert torch.allclose(opt_state.cell_positions, expected_cp, atol=1e-5), ( + f"cell_positions desynced from actual cell: " + f"max diff = {(opt_state.cell_positions - expected_cp).abs().max().item():.2e}" ) - # Without clamping, 1.05^20 = 2.65x → strain ~1.65, far over 0.15 - # Verify it's actually being clamped (not just small steps) - assert max_strain > 0.10, f"Strain {max_strain:.4f} suspiciously low" + @pytest.mark.parametrize( + "optimizer", + [ + pytest.param((fire_init, fire_step), id="fire"), + pytest.param((lbfgs_init, lbfgs_step), id="lbfgs"), + ], + ) + def test_optimizer_sym_batch1_matches_batchN( + self, + model: LennardJonesModel, + optimizer: tuple, + ) -> None: + """Batch=1 and batch=2 give the same per-system trajectory with FixSymmetry.""" + atoms = make_structure("hcp") + n_check_steps = 10 + init_fn, step_fn = optimizer + step_kwargs = {} + if init_fn is fire_init: + step_kwargs["fire_flavor"] = "ase_fire" + + # Batch=1 run + state1 = ts.io.atoms_to_state(atoms, DEVICE, DTYPE) + c1 = FixSymmetry.from_state(state1, symprec=SYMPREC) + state1.constraints = [c1] + state1.cell = state1.cell * 0.95 + state1.positions = state1.positions * 0.95 + s1 = init_fn(state1, model, cell_filter=ts.CellFilter.frechet) + energies_1 = [s1.energy.item()] + for _ in range(n_check_steps): + s1 = step_fn(state=s1, model=model, **step_kwargs) + energies_1.append(s1.energy.item()) + + # Batch=2 run (two copies of the same structure) + state2 = ts.io.atoms_to_state([atoms, atoms], DEVICE, DTYPE) + c2 = FixSymmetry.from_state(state2, symprec=SYMPREC) + state2.constraints = [c2] + state2.cell = state2.cell * 0.95 + state2.positions = state2.positions * 0.95 + s2 = init_fn(state2, model, cell_filter=ts.CellFilter.frechet) + energies_2_sys0 = [s2.energy[0].item()] + for _ in range(n_check_steps): + s2 = 
step_fn(state=s2, model=model, **step_kwargs) + energies_2_sys0.append(s2.energy[0].item()) + + # Per-step energies should match + for step, (e1, e2) in enumerate(zip(energies_1, energies_2_sys0)): + assert abs(e1 - e2) < 1e-4, ( + f"Energy diverged at step {step}: batch=1 {e1:.6f} vs " + f"batch=2[sys0] {e2:.6f} (diff={abs(e1-e2):.2e})" + ) + + @pytest.mark.parametrize( + "optimizer", + [ + pytest.param(ts.Optimizer.fire, id="fire"), + pytest.param(ts.Optimizer.lbfgs, id="lbfgs"), + ], + ) + def test_optimizer_sym_converges( + self, + noisy_lj_model: NoisyModelWrapper, + optimizer: ts.Optimizer, + ) -> None: + """Optimizer with FixSymmetry + Frechet converges on anisotropically strained HCP. - # Symmetry should still be detectable - datasets = get_symmetry_datasets(state, symprec=SYMPREC) - assert datasets[0].number == SPACEGROUPS["fcc"] + Uses HCP with anisotropic strain (a-axis compressed, c-axis stretched) + so the cell actively wants to change shape under symmetry constraints. + Asserts the optimizer converges within MAX_STEPS (not just preserves symmetry). 
+ """ + state = ts.io.atoms_to_state(make_structure("hcp"), DEVICE, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + state.constraints = [constraint] + # Anisotropic strain: compress a/b by 10%, stretch c by 10% + strain = torch.eye(3, dtype=DTYPE) + strain[0, 0] = 0.90 + strain[1, 1] = 0.90 + strain[2, 2] = 1.10 + state.cell = torch.bmm(state.cell, strain.unsqueeze(0).expand_as(state.cell)) + state.positions = state.positions @ strain + + convergence_fn = ts.generate_force_convergence_fn( + force_tol=0.01, include_cell_forces=True, + ) + final_state = ts.optimize( + system=state, + model=noisy_lj_model, + optimizer=optimizer, + convergence_fn=convergence_fn, + init_kwargs={"cell_filter": ts.CellFilter.frechet}, + max_steps=MAX_STEPS, + steps_between_swaps=1, + ) + fmax = ts.system_wise_max_force(final_state).item() + cell_fmax = final_state.cell_forces.norm(dim=2).max().item() + assert fmax < 0.01, f"Atomic forces not converged: fmax={fmax:.4f}" + assert cell_fmax < 0.01, f"Cell forces not converged: cell_fmax={cell_fmax:.4f}" diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index ec4ece699..c71b65412 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -10,6 +10,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import BFGSState, FireFlavor, FireState, LBFGSState, OptimState +from torch_sim.optimizers.cell_filters import CellLBFGSState, deform_grad from torch_sim.state import SimState @@ -1492,3 +1493,55 @@ def test_optimizer_preserves_charge_spin( assert torch.allclose(opt_state.charge, original_charge) assert torch.allclose(opt_state.spin, original_spin) + + +def test_lbfgs_prev_cell_positions_stored_before_step( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """prev_cell_positions captures start-of-step, prev_positions use adjusted frame. 
+ """ + from ase.build import bulk + + from torch_sim.constraints import FixSymmetry + + atoms = bulk("Ti", "hcp", a=2.95, c=4.68).repeat([2, 2, 2]) + state = ts.io.atoms_to_state(atoms, lj_model.device, lj_model.dtype) + constraint = FixSymmetry.from_state(state, symprec=0.01) + state.constraints = [constraint] + state.cell = state.cell * 0.95 + state.positions = state.positions * 0.95 + + opt_state = ts.lbfgs_init( + state, lj_model, cell_filter=ts.CellFilter.frechet + ) + assert isinstance(opt_state, CellLBFGSState) + + # Save cell_positions BEFORE the step + cell_pos_before = opt_state.cell_positions.clone() + + # Run one step + opt_state = ts.lbfgs_step(state=opt_state, model=lj_model) + + # prev_cell_positions should equal the pre-step value (not the post-resync value) + assert torch.allclose(opt_state.prev_cell_positions, cell_pos_before, atol=1e-6), ( + "prev_cell_positions should capture start-of-step, not post-resync. " + f"max diff from pre-step = {(opt_state.prev_cell_positions - cell_pos_before).abs().max():.2e}, " + f"max diff from current = {(opt_state.prev_cell_positions - opt_state.cell_positions).abs().max():.2e}" + ) + + # prev_cell_positions should NOT equal the current (post-step) cell_positions + # (unless the step was zero, which shouldn't happen on a compressed structure) + assert not torch.allclose(opt_state.prev_cell_positions, opt_state.cell_positions, atol=1e-6), ( + "prev_cell_positions equals current cell_positions — s_new_cell would be zero" + ) + + # prev_positions should be in the adjusted (post-adjust_cell) frame + cur_dg = deform_grad(opt_state.reference_cell.mT, opt_state.row_vector_cell) + expected_prev = torch.linalg.solve( + cur_dg[opt_state.system_idx], + opt_state.positions.unsqueeze(-1), + ).squeeze(-1) + assert torch.allclose(opt_state.prev_positions, expected_prev, atol=1e-5), ( + "prev_positions should be fractional coords in the adjusted cell frame. 
" + f"max diff = {(opt_state.prev_positions - expected_prev).abs().max():.2e}" + ) From 54fffe85cb56b9becef10117def5e2344ddb305a Mon Sep 17 00:00:00 2001 From: Pascal Salzbrenner Date: Fri, 24 Apr 2026 14:23:46 +0000 Subject: [PATCH 4/7] Remove references to reference_cells --- tests/test_constraints.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 82ce8b877..df228e705 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1134,8 +1134,7 @@ def test_fix_symmetry_system_idx_remapped_on_reordered_slice( mixed_double_sim_state: ts.SimState, ) -> None: """Slicing with reversed system order must remap FixSymmetry so each - system's rotations/symm_maps/reference_cells stay paired with the - correct output system. + system's rotations/symm_maps stay paired with the correct output system. """ state = mixed_double_sim_state # 2 systems @@ -1149,15 +1148,11 @@ def test_fix_symmetry_system_idx_remapped_on_reordered_slice( smap0 = torch.arange(n0).unsqueeze(0) # (1, n0) smap1 = torch.arange(n1).unsqueeze(0) # (1, n1) - ref0 = state.row_vector_cell[0].clone() - ref1 = state.row_vector_cell[1].clone() - state.constraints = [ FixSymmetry( rotations=[rot0, rot1], symm_maps=[smap0, smap1], system_idx=torch.tensor([0, 1]), - reference_cells=[ref0, ref1], ) ] @@ -1177,13 +1172,10 @@ def test_fix_symmetry_system_idx_remapped_on_reordered_slice( # Output system 0 = old system 1 → should use rot1 ci_for_output0 = si_to_ci[0] assert torch.equal(c.rotations[ci_for_output0], rot1) - assert c.reference_cells is not None - assert torch.equal(c.reference_cells[ci_for_output0], ref1) # Output system 1 = old system 0 → should use rot0 ci_for_output1 = si_to_ci[1] assert torch.equal(c.rotations[ci_for_output1], rot0) - assert torch.equal(c.reference_cells[ci_for_output1], ref0) def test_fix_com_system_idx_remapped_on_reordered_slice( @@ -1247,10 +1239,9 @@ def 
test_fix_com_dtype_propagation(self, ar_supercell_sim_state: ts.SimState) -> @pytest.mark.parametrize("target_dtype", [torch.float32, torch.float64]) def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None: - """FixSymmetry rotations and reference_cells must follow dtype changes.""" + """FixSymmetry rotations must follow dtype changes.""" rotations = [torch.eye(3, dtype=torch.float64).unsqueeze(0)] symm_maps = [torch.zeros(1, 2, dtype=torch.long)] - ref_cells = [torch.eye(3, dtype=torch.float64)] state = ts.SimState( positions=torch.zeros(2, 3, dtype=torch.float64), @@ -1260,14 +1251,12 @@ def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None atomic_numbers=torch.tensor([14, 14]), system_idx=torch.zeros(2, dtype=torch.long), ) - state.constraints = [FixSymmetry(rotations, symm_maps, reference_cells=ref_cells)] + state.constraints = [FixSymmetry(rotations, symm_maps)] new_state = state.to(dtype=target_dtype) c = new_state.constraints[0] assert isinstance(c, FixSymmetry) assert c.rotations[0].dtype == target_dtype - assert c.reference_cells is not None - assert c.reference_cells[0].dtype == target_dtype # integer symm_maps must stay long assert c.symm_maps[0].dtype == torch.long # original constraint unchanged From b3b28215bafad951917b0d8f4df127a5d6844e7c Mon Sep 17 00:00:00 2001 From: Pascal Salzbrenner Date: Fri, 24 Apr 2026 14:38:33 +0000 Subject: [PATCH 5/7] Apply TorchSim formatting with pre-commit checks Co-authored-by: Copilot --- tests/test_fix_symmetry.py | 20 ++++++++++---------- tests/test_optimizers.py | 21 ++++++++------------- torch_sim/optimizers/fire.py | 12 +++--------- torch_sim/optimizers/lbfgs.py | 7 ++----- 4 files changed, 23 insertions(+), 37 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index bf1b6222c..9155a4cbc 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -15,7 +15,7 @@ from torch_sim.constraints import FixCom, FixSymmetry 
from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.optimizers.cell_filters import CellLBFGSState, deform_grad +from torch_sim.optimizers.cell_filters import deform_grad from torch_sim.optimizers.fire import fire_init, fire_step from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step from torch_sim.symmetrize import get_symmetry_datasets @@ -699,13 +699,12 @@ def test_cell_positions_consistent_after_step( # Recompute expected cell_positions from the actual cell cur_dg = deform_grad(opt_state.reference_cell.mT, opt_state.row_vector_cell) - expected_cp = ( - ts.math.matrix_log_33(cur_dg, sim_dtype=opt_state.dtype) - * opt_state.cell_factor.view(opt_state.n_systems, 1, 1) - ) + expected_cp = ts.math.matrix_log_33( + cur_dg, sim_dtype=opt_state.dtype + ) * opt_state.cell_factor.view(opt_state.n_systems, 1, 1) assert torch.allclose(opt_state.cell_positions, expected_cp, atol=1e-5), ( f"cell_positions desynced from actual cell: " - f"max diff = {(opt_state.cell_positions - expected_cp).abs().max().item():.2e}" + f"max diff = {(opt_state.cell_positions - expected_cp).abs().max().item():.2e}" # NOQA: E501 ) @pytest.mark.parametrize( @@ -715,7 +714,7 @@ def test_cell_positions_consistent_after_step( pytest.param((lbfgs_init, lbfgs_step), id="lbfgs"), ], ) - def test_optimizer_sym_batch1_matches_batchN( + def test_optimizer_sym_batch1_matches_batch_n( self, model: LennardJonesModel, optimizer: tuple, @@ -753,10 +752,10 @@ def test_optimizer_sym_batch1_matches_batchN( energies_2_sys0.append(s2.energy[0].item()) # Per-step energies should match - for step, (e1, e2) in enumerate(zip(energies_1, energies_2_sys0)): + for step, (e1, e2) in enumerate(zip(energies_1, energies_2_sys0, strict=True)): assert abs(e1 - e2) < 1e-4, ( f"Energy diverged at step {step}: batch=1 {e1:.6f} vs " - f"batch=2[sys0] {e2:.6f} (diff={abs(e1-e2):.2e})" + f"batch=2[sys0] {e2:.6f} (diff={abs(e1 - e2):.2e})" ) 
@pytest.mark.parametrize( @@ -789,7 +788,8 @@ def test_optimizer_sym_converges( state.positions = state.positions @ strain convergence_fn = ts.generate_force_convergence_fn( - force_tol=0.01, include_cell_forces=True, + force_tol=0.01, + include_cell_forces=True, ) final_state = ts.optimize( system=state, diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index c71b65412..3f419f688 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1495,11 +1495,8 @@ def test_optimizer_preserves_charge_spin( assert torch.allclose(opt_state.spin, original_spin) -def test_lbfgs_prev_cell_positions_stored_before_step( - ar_supercell_sim_state: SimState, lj_model: ModelInterface -) -> None: - """prev_cell_positions captures start-of-step, prev_positions use adjusted frame. - """ +def test_lbfgs_prev_cell_positions_stored_before_step(lj_model: ModelInterface) -> None: + """prev_cell_positions captures start-of-step, prev_positions use adjusted frame.""" from ase.build import bulk from torch_sim.constraints import FixSymmetry @@ -1511,9 +1508,7 @@ def test_lbfgs_prev_cell_positions_stored_before_step( state.cell = state.cell * 0.95 state.positions = state.positions * 0.95 - opt_state = ts.lbfgs_init( - state, lj_model, cell_filter=ts.CellFilter.frechet - ) + opt_state = ts.lbfgs_init(state, lj_model, cell_filter=ts.CellFilter.frechet) assert isinstance(opt_state, CellLBFGSState) # Save cell_positions BEFORE the step @@ -1525,15 +1520,15 @@ def test_lbfgs_prev_cell_positions_stored_before_step( # prev_cell_positions should equal the pre-step value (not the post-resync value) assert torch.allclose(opt_state.prev_cell_positions, cell_pos_before, atol=1e-6), ( "prev_cell_positions should capture start-of-step, not post-resync. 
" - f"max diff from pre-step = {(opt_state.prev_cell_positions - cell_pos_before).abs().max():.2e}, " - f"max diff from current = {(opt_state.prev_cell_positions - opt_state.cell_positions).abs().max():.2e}" + f"max diff from pre-step = {(opt_state.prev_cell_positions - cell_pos_before).abs().max():.2e}, " # noqa: E501 + f"max diff from current = {(opt_state.prev_cell_positions - opt_state.cell_positions).abs().max():.2e}" # noqa: E501 ) # prev_cell_positions should NOT equal the current (post-step) cell_positions # (unless the step was zero, which shouldn't happen on a compressed structure) - assert not torch.allclose(opt_state.prev_cell_positions, opt_state.cell_positions, atol=1e-6), ( - "prev_cell_positions equals current cell_positions — s_new_cell would be zero" - ) + assert not torch.allclose( + opt_state.prev_cell_positions, opt_state.cell_positions, atol=1e-6 + ), "prev_cell_positions equals current cell_positions — s_new_cell would be zero" # prev_positions should be in the adjusted (post-adjust_cell) frame cur_dg = deform_grad(opt_state.reference_cell.mT, opt_state.row_vector_cell) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 66d8b83fc..b4e23b95e 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -458,19 +458,13 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 state.reference_cell.mT, state.row_vector_cell ) if is_frechet: - cell_factor_reshaped = state.cell_factor.view( - state.n_systems, 1, 1 - ) + cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1) state.cell_positions = ( - tsm.matrix_log_33( - adjusted_deform_grad, sim_dtype=state.dtype - ) + tsm.matrix_log_33(adjusted_deform_grad, sim_dtype=state.dtype) * cell_factor_reshaped ) else: - cell_factor_expanded = state.cell_factor.expand( - state.n_systems, 3, 1 - ) + cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) state.cell_positions = ( 
adjusted_deform_grad.reshape(state.n_systems, 3, 3) * cell_factor_expanded diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 482fde6cc..c6cdb01a5 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -513,16 +513,13 @@ def lbfgs_step( # noqa: PLR0915, C901 if is_frechet: cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1) state.cell_positions = ( - ts.math.matrix_log_33( - adjusted_deform_grad, sim_dtype=state.dtype - ) + ts.math.matrix_log_33(adjusted_deform_grad, sim_dtype=state.dtype) * cell_factor_reshaped ) else: cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) state.cell_positions = ( - adjusted_deform_grad.reshape(n_systems, 3, 3) - * cell_factor_expanded + adjusted_deform_grad.reshape(n_systems, 3, 3) * cell_factor_expanded ) # Store prev_positions/prev_forces in the ADJUSTED cell's frame so From 9c18addbf979db6b5628cd2b3042a63c7e552bec Mon Sep 17 00:00:00 2001 From: Pascal Salzbrenner Date: Tue, 28 Apr 2026 08:40:44 +0000 Subject: [PATCH 6/7] Docstring cleanup --- tests/test_fix_symmetry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 9155a4cbc..db7e09d7d 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -666,8 +666,8 @@ def test_noisy_model_preserves_symmetry_with_constraint( class TestFixSymmetryCellPositionsResync: """Tests that cell_positions stays consistent with the actual cell after - optimizer steps with FixSymmetry. These would have caught the - cell_positions desync bug (Fix 1) and the batching discrepancy. + optimizer steps with FixSymmetry. These catch cell_positions desyncs and + batching discrepancies. 
""" @pytest.mark.parametrize( From 96d84c5baa9463f0ffaf9719d21e98eb3a7786b8 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 4 May 2026 10:39:19 -0400 Subject: [PATCH 7/7] remove hard coded 0.25 --- torch_sim/constraints.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 66231b26f..02bc03043 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -913,18 +913,22 @@ def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: rots = self.rotations[ci].to(dtype=dtype) stress[si] = symmetrize_rank2(state.row_vector_cell[si], stress[si], rots) - def adjust_cell(self, state: SimState, cell: torch.Tensor) -> None: + def adjust_cell( + self, state: SimState, cell: torch.Tensor, max_delta_component: float = 0.25 + ) -> None: """Symmetrize cell deformation gradient in-place. Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``. - Per-step deformation is clamped at 0.25 to avoid - ill-conditioned symmetrization, matching the ASE FixSymmetry behaviour. + Per-step deformation is clamped at ``max_delta_component`` to avoid + ill-conditioned symmetrization; the default (0.25) matches ASE FixSymmetry. Args: state: Current simulation state. cell: Cell tensor (n_systems, 3, 3) in column vector convention. + max_delta_component: Maximum allowed component of the per-step + deformation ``F - I``; larger deformations are scaled down to it. Raises: RuntimeError: If deformation gradient contains NaN or Inf. @@ -948,8 +952,8 @@ def adjust_cell(self, state: SimState, cell: torch.Tensor) -> None: f"FixSymmetry: deformation gradient is {max_delta}, " f"cell may be singular or ill-conditioned." 
) - if max_delta > 0.25: - deform_delta = deform_delta * (0.25 / max_delta) + if max_delta > max_delta_component: + deform_delta = deform_delta * (max_delta_component / max_delta) # Symmetrize the per-step deformation rots = self.rotations[ci].to(dtype=state.dtype)