Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,57 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) ->
f"Energy for system {step} doesn't match position only optimization: "
f"system={energy_unit_cell}, individual={individual_energies_fire[step]}"
)


# Regression test: optimizers must carry charge and spin through init/steps.
# GitHub Issue https://github.com/TorchSim/torch-sim/issues/389
@pytest.mark.parametrize(
    ("optimizer_fn", "cell_filter"),
    [
        (ts.Optimizer.fire, None),
        (ts.Optimizer.gradient_descent, None),
        (ts.Optimizer.fire, ts.CellFilter.unit),
        (ts.Optimizer.gradient_descent, ts.CellFilter.frechet),
    ],
)
def test_optimizer_preserves_charge_spin(
    optimizer_fn: ts.Optimizer,
    cell_filter: ts.CellFilter | None,
    ar_supercell_sim_state: SimState,
    lj_model: ModelInterface,
) -> None:
    """Test that optimizers preserve charge and spin through initialization and steps."""
    # Perturb positions so the optimizer has non-trivial forces to relax.
    noise = torch.randn_like(ar_supercell_sim_state.positions) * 0.1
    ar_supercell_sim_state.positions = ar_supercell_sim_state.positions + noise

    # Attach distinctive non-zero charge/spin values to detect accidental resets.
    device = ar_supercell_sim_state.device
    dtype = ar_supercell_sim_state.dtype
    expected_charge = torch.tensor([5.0], device=device, dtype=dtype)
    expected_spin = torch.tensor([6.0], device=device, dtype=dtype)
    ar_supercell_sim_state.charge = expected_charge.clone()
    ar_supercell_sim_state.spin = expected_spin.clone()

    init_fn, step_fn = ts.OPTIM_REGISTRY[optimizer_fn]
    opt_state = init_fn(
        model=lj_model, state=ar_supercell_sim_state, cell_filter=cell_filter
    )

    # Initialization must not drop or mutate charge/spin.
    assert torch.allclose(opt_state.charge, expected_charge)
    assert torch.allclose(opt_state.spin, expected_spin)

    # Each optimizer family takes different step hyperparameters.
    step_kwargs = (
        {"dt_max": 0.3}
        if optimizer_fn == ts.Optimizer.fire
        else {"pos_lr": 0.01, "cell_lr": 0.1}
    )

    # Charge/spin must survive every optimization step, not just the first.
    for _ in range(3):
        opt_state = step_fn(state=opt_state, model=lj_model, **step_kwargs)
        assert torch.allclose(opt_state.charge, expected_charge)
        assert torch.allclose(opt_state.spin, expected_spin)
2 changes: 2 additions & 0 deletions torch_sim/optimizers/fire.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def fire_init(
"system_idx": state.system_idx.clone(),
"_constraints": state.constraints,
"pbc": state.pbc,
"charge": state.charge.clone(),
"spin": state.spin.clone(),
# Optimization state
"forces": forces,
"energy": energy,
Expand Down
2 changes: 2 additions & 0 deletions torch_sim/optimizers/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def gradient_descent_init(
"atomic_numbers": state.atomic_numbers,
"system_idx": state.system_idx,
"_constraints": state.constraints,
"charge": state.charge,
"spin": state.spin,
}

if cell_filter is not None: # Create cell optimization state
Expand Down
Loading