Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,8 +1134,7 @@ def test_fix_symmetry_system_idx_remapped_on_reordered_slice(
mixed_double_sim_state: ts.SimState,
) -> None:
"""Slicing with reversed system order must remap FixSymmetry so each
system's rotations/symm_maps/reference_cells stay paired with the
correct output system.
system's rotations/symm_maps stay paired with the correct output system.
"""
state = mixed_double_sim_state # 2 systems

Expand All @@ -1149,15 +1148,11 @@ def test_fix_symmetry_system_idx_remapped_on_reordered_slice(
smap0 = torch.arange(n0).unsqueeze(0) # (1, n0)
smap1 = torch.arange(n1).unsqueeze(0) # (1, n1)

ref0 = state.row_vector_cell[0].clone()
ref1 = state.row_vector_cell[1].clone()

state.constraints = [
FixSymmetry(
rotations=[rot0, rot1],
symm_maps=[smap0, smap1],
system_idx=torch.tensor([0, 1]),
reference_cells=[ref0, ref1],
)
]

Expand All @@ -1177,13 +1172,10 @@ def test_fix_symmetry_system_idx_remapped_on_reordered_slice(
# Output system 0 = old system 1 → should use rot1
ci_for_output0 = si_to_ci[0]
assert torch.equal(c.rotations[ci_for_output0], rot1)
assert c.reference_cells is not None
assert torch.equal(c.reference_cells[ci_for_output0], ref1)

# Output system 1 = old system 0 → should use rot0
ci_for_output1 = si_to_ci[1]
assert torch.equal(c.rotations[ci_for_output1], rot0)
assert torch.equal(c.reference_cells[ci_for_output1], ref0)


def test_fix_com_system_idx_remapped_on_reordered_slice(
Expand Down Expand Up @@ -1247,10 +1239,9 @@ def test_fix_com_dtype_propagation(self, ar_supercell_sim_state: ts.SimState) ->

@pytest.mark.parametrize("target_dtype", [torch.float32, torch.float64])
def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None:
"""FixSymmetry rotations and reference_cells must follow dtype changes."""
"""FixSymmetry rotations must follow dtype changes."""
rotations = [torch.eye(3, dtype=torch.float64).unsqueeze(0)]
symm_maps = [torch.zeros(1, 2, dtype=torch.long)]
ref_cells = [torch.eye(3, dtype=torch.float64)]

state = ts.SimState(
positions=torch.zeros(2, 3, dtype=torch.float64),
Expand All @@ -1260,14 +1251,12 @@ def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None
atomic_numbers=torch.tensor([14, 14]),
system_idx=torch.zeros(2, dtype=torch.long),
)
state.constraints = [FixSymmetry(rotations, symm_maps, reference_cells=ref_cells)]
state.constraints = [FixSymmetry(rotations, symm_maps)]

new_state = state.to(dtype=target_dtype)
c = new_state.constraints[0]
assert isinstance(c, FixSymmetry)
assert c.rotations[0].dtype == target_dtype
assert c.reference_cells is not None
assert c.reference_cells[0].dtype == target_dtype
# integer symm_maps must stay long
assert c.symm_maps[0].dtype == torch.long
# original constraint unchanged
Expand Down
184 changes: 141 additions & 43 deletions tests/test_fix_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from torch_sim.constraints import FixCom, FixSymmetry
from torch_sim.models.interface import ModelInterface
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.optimizers.cell_filters import deform_grad
from torch_sim.optimizers.fire import fire_init, fire_step
from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step
from torch_sim.symmetrize import get_symmetry_datasets


Expand Down Expand Up @@ -285,9 +288,8 @@ def test_large_deformation_clamped(self) -> None:
assert not torch.allclose(new_cell, orig_cell * 1.5, atol=1e-6)
# Per-step clamp limits single-step strain to 0.25
identity = torch.eye(3, dtype=DTYPE)
assert constraint.reference_cells is not None
ref_cell = constraint.reference_cells[0]
strain = torch.linalg.solve(ref_cell, new_cell[0].mT) - identity
cur_cell = state.row_vector_cell[0]
strain = torch.linalg.solve(cur_cell, new_cell[0].mT) - identity
assert torch.abs(strain).max().item() <= 0.25 + 1e-6

def test_nan_deformation_raises(self) -> None:
Expand All @@ -300,15 +302,11 @@ def test_nan_deformation_raises(self) -> None:
constraint.adjust_cell(state, new_cell)

def test_init_mismatched_lengths_raises(self) -> None:
"""Mismatched rotations/symm_maps/reference_cells lengths raise ValueError."""
"""Mismatched rotations/symm_maps lengths raise ValueError."""
rots = [torch.eye(3).unsqueeze(0)]
smaps = [torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 2, dtype=torch.long)]
with pytest.raises(ValueError, match="length mismatch"):
FixSymmetry(rots, smaps)
# reference_cells length must match n_systems
smaps_ok = [torch.zeros(1, 1, dtype=torch.long)]
with pytest.raises(ValueError, match="reference_cells length"):
FixSymmetry(rots, smaps_ok, reference_cells=[torch.eye(3), torch.eye(3)])

@pytest.mark.parametrize("method", ["adjust_positions", "adjust_cell"])
def test_adjust_skipped_when_disabled(self, method: str) -> None:
Expand Down Expand Up @@ -665,44 +663,144 @@ def test_noisy_model_preserves_symmetry_with_constraint(
assert result["initial_spacegroups"][0] == 229
assert result["final_spacegroups"][0] == 229

def test_cumulative_strain_clamp_direct(self) -> None:
"""adjust_cell clamps deformation when cumulative strain exceeds limit.

Directly tests the clamping mechanism by repeatedly applying small
cell deformations that individually pass the per-step check (< 0.25)
but cumulatively exceed max_cumulative_strain. Verifies:
1. The cell doesn't drift beyond the strain envelope
2. Symmetry is preserved after many small steps
"""
state = ts.io.atoms_to_state(make_structure("fcc", repeats=1), DEVICE, DTYPE)
class TestFixSymmetryCellPositionsResync:
"""Tests that cell_positions stays consistent with the actual cell after
optimizer steps with FixSymmetry. These catch cell_positions desyncs and
batching discrepancies.
"""

@pytest.mark.parametrize(
"optimizer",
[
pytest.param((fire_init, fire_step), id="fire"),
pytest.param((lbfgs_init, lbfgs_step), id="lbfgs"),
],
)
def test_cell_positions_consistent_after_step(
self,
model: LennardJonesModel,
optimizer: tuple,
) -> None:
"""cell_positions matches actual cell after one step with FixSymmetry."""
state = ts.io.atoms_to_state(make_structure("hcp"), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
constraint.max_cumulative_strain = 0.15
assert constraint.reference_cells is not None
ref_cell = constraint.reference_cells[0].clone()
state.constraints = [constraint]
state.cell = state.cell * 0.95
state.positions = state.positions * 0.95

# Apply 20 small deformations (each ~5% along one axis)
# Total would be ~100% without clamping, well over the 0.15 limit
identity = torch.eye(3, dtype=DTYPE)
for _ in range(20):
# Stretch c-axis by 5% (cubic symmetrization isotropizes this)
stretch = identity.clone()
stretch[2, 2] = 1.05
new_cell = (state.row_vector_cell[0] @ stretch).mT.unsqueeze(0)
constraint.adjust_cell(state, new_cell)
state.cell = new_cell

# Cumulative strain must be clamped to the limit
final_cell = state.row_vector_cell[0]
cumulative = torch.linalg.solve(ref_cell, final_cell) - identity
max_strain = torch.abs(cumulative).max().item()
assert max_strain <= constraint.max_cumulative_strain + 1e-6, (
f"Strain {max_strain:.4f} exceeded {constraint.max_cumulative_strain}"
init_fn, step_fn = optimizer
opt_state = init_fn(state, model, cell_filter=ts.CellFilter.frechet)

step_kwargs = {}
if init_fn is fire_init:
step_kwargs["fire_flavor"] = "ase_fire"
opt_state = step_fn(state=opt_state, model=model, **step_kwargs)

# Recompute expected cell_positions from the actual cell
cur_dg = deform_grad(opt_state.reference_cell.mT, opt_state.row_vector_cell)
expected_cp = ts.math.matrix_log_33(
cur_dg, sim_dtype=opt_state.dtype
) * opt_state.cell_factor.view(opt_state.n_systems, 1, 1)
assert torch.allclose(opt_state.cell_positions, expected_cp, atol=1e-5), (
f"cell_positions desynced from actual cell: "
f"max diff = {(opt_state.cell_positions - expected_cp).abs().max().item():.2e}" # NOQA: E501
)

# Without clamping, 1.05^20 = 2.65x → strain ~1.65, far over 0.15
# Verify it's actually being clamped (not just small steps)
assert max_strain > 0.10, f"Strain {max_strain:.4f} suspiciously low"
@pytest.mark.parametrize(
"optimizer",
[
pytest.param((fire_init, fire_step), id="fire"),
pytest.param((lbfgs_init, lbfgs_step), id="lbfgs"),
],
)
def test_optimizer_sym_batch1_matches_batch_n(
self,
model: LennardJonesModel,
optimizer: tuple,
) -> None:
"""Batch=1 and batch=2 give the same per-system trajectory with FixSymmetry."""
atoms = make_structure("hcp")
n_check_steps = 10
init_fn, step_fn = optimizer
step_kwargs = {}
if init_fn is fire_init:
step_kwargs["fire_flavor"] = "ase_fire"

# Batch=1 run
state1 = ts.io.atoms_to_state(atoms, DEVICE, DTYPE)
c1 = FixSymmetry.from_state(state1, symprec=SYMPREC)
state1.constraints = [c1]
state1.cell = state1.cell * 0.95
state1.positions = state1.positions * 0.95
s1 = init_fn(state1, model, cell_filter=ts.CellFilter.frechet)
energies_1 = [s1.energy.item()]
for _ in range(n_check_steps):
s1 = step_fn(state=s1, model=model, **step_kwargs)
energies_1.append(s1.energy.item())

# Batch=2 run (two copies of the same structure)
state2 = ts.io.atoms_to_state([atoms, atoms], DEVICE, DTYPE)
c2 = FixSymmetry.from_state(state2, symprec=SYMPREC)
state2.constraints = [c2]
state2.cell = state2.cell * 0.95
state2.positions = state2.positions * 0.95
s2 = init_fn(state2, model, cell_filter=ts.CellFilter.frechet)
energies_2_sys0 = [s2.energy[0].item()]
for _ in range(n_check_steps):
s2 = step_fn(state=s2, model=model, **step_kwargs)
energies_2_sys0.append(s2.energy[0].item())

# Per-step energies should match
for step, (e1, e2) in enumerate(zip(energies_1, energies_2_sys0, strict=True)):
assert abs(e1 - e2) < 1e-4, (
f"Energy diverged at step {step}: batch=1 {e1:.6f} vs "
f"batch=2[sys0] {e2:.6f} (diff={abs(e1 - e2):.2e})"
)

# Symmetry should still be detectable
datasets = get_symmetry_datasets(state, symprec=SYMPREC)
assert datasets[0].number == SPACEGROUPS["fcc"]
@pytest.mark.parametrize(
"optimizer",
[
pytest.param(ts.Optimizer.fire, id="fire"),
pytest.param(ts.Optimizer.lbfgs, id="lbfgs"),
],
)
def test_optimizer_sym_converges(
self,
noisy_lj_model: NoisyModelWrapper,
optimizer: ts.Optimizer,
) -> None:
"""Optimizer with FixSymmetry + Frechet converges on anisotropically strained HCP.

Uses HCP with anisotropic strain (a-axis compressed, c-axis stretched)
so the cell actively wants to change shape under symmetry constraints.
Asserts the optimizer converges within MAX_STEPS (not just preserves symmetry).
"""
state = ts.io.atoms_to_state(make_structure("hcp"), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
state.constraints = [constraint]
# Anisotropic strain: compress a/b by 10%, stretch c by 10%
strain = torch.eye(3, dtype=DTYPE)
strain[0, 0] = 0.90
strain[1, 1] = 0.90
strain[2, 2] = 1.10
state.cell = torch.bmm(state.cell, strain.unsqueeze(0).expand_as(state.cell))
state.positions = state.positions @ strain

convergence_fn = ts.generate_force_convergence_fn(
force_tol=0.01,
include_cell_forces=True,
)
final_state = ts.optimize(
system=state,
model=noisy_lj_model,
optimizer=optimizer,
convergence_fn=convergence_fn,
init_kwargs={"cell_filter": ts.CellFilter.frechet},
max_steps=MAX_STEPS,
steps_between_swaps=1,
)
fmax = ts.system_wise_max_force(final_state).item()
cell_fmax = final_state.cell_forces.norm(dim=2).max().item()
assert fmax < 0.01, f"Atomic forces not converged: fmax={fmax:.4f}"
assert cell_fmax < 0.01, f"Cell forces not converged: cell_fmax={cell_fmax:.4f}"
48 changes: 48 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.optimizers import BFGSState, FireFlavor, FireState, LBFGSState, OptimState
from torch_sim.optimizers.cell_filters import CellLBFGSState, deform_grad
from torch_sim.state import SimState


Expand Down Expand Up @@ -1492,3 +1493,50 @@ def test_optimizer_preserves_charge_spin(

assert torch.allclose(opt_state.charge, original_charge)
assert torch.allclose(opt_state.spin, original_spin)


def test_lbfgs_prev_cell_positions_stored_before_step(lj_model: ModelInterface) -> None:
"""prev_cell_positions captures start-of-step, prev_positions use adjusted frame."""
from ase.build import bulk

from torch_sim.constraints import FixSymmetry

atoms = bulk("Ti", "hcp", a=2.95, c=4.68).repeat([2, 2, 2])
state = ts.io.atoms_to_state(atoms, lj_model.device, lj_model.dtype)
constraint = FixSymmetry.from_state(state, symprec=0.01)
state.constraints = [constraint]
state.cell = state.cell * 0.95
state.positions = state.positions * 0.95

opt_state = ts.lbfgs_init(state, lj_model, cell_filter=ts.CellFilter.frechet)
assert isinstance(opt_state, CellLBFGSState)

# Save cell_positions BEFORE the step
cell_pos_before = opt_state.cell_positions.clone()

# Run one step
opt_state = ts.lbfgs_step(state=opt_state, model=lj_model)

# prev_cell_positions should equal the pre-step value (not the post-resync value)
assert torch.allclose(opt_state.prev_cell_positions, cell_pos_before, atol=1e-6), (
"prev_cell_positions should capture start-of-step, not post-resync. "
f"max diff from pre-step = {(opt_state.prev_cell_positions - cell_pos_before).abs().max():.2e}, " # noqa: E501
f"max diff from current = {(opt_state.prev_cell_positions - opt_state.cell_positions).abs().max():.2e}" # noqa: E501
)

# prev_cell_positions should NOT equal the current (post-step) cell_positions
# (unless the step was zero, which shouldn't happen on a compressed structure)
assert not torch.allclose(
opt_state.prev_cell_positions, opt_state.cell_positions, atol=1e-6
), "prev_cell_positions equals current cell_positions — s_new_cell would be zero"

# prev_positions should be in the adjusted (post-adjust_cell) frame
cur_dg = deform_grad(opt_state.reference_cell.mT, opt_state.row_vector_cell)
expected_prev = torch.linalg.solve(
cur_dg[opt_state.system_idx],
opt_state.positions.unsqueeze(-1),
).squeeze(-1)
assert torch.allclose(opt_state.prev_positions, expected_prev, atol=1e-5), (
"prev_positions should be fractional coords in the adjusted cell frame. "
f"max diff = {(opt_state.prev_positions - expected_prev).abs().max():.2e}"
)
Loading
Loading