diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index decebe76e..0acb9835a 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -907,3 +907,57 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) ->
             f"Energy for system {step} doesn't match position only optimization: "
             f"system={energy_unit_cell}, individual={individual_energies_fire[step]}"
         )
+
+
+# Test for charge and spin preservation
+# GitHub Issue https://github.com/TorchSim/torch-sim/issues/389
+@pytest.mark.parametrize(
+    ("optimizer_fn", "cell_filter"),
+    [
+        (ts.Optimizer.fire, None),
+        (ts.Optimizer.gradient_descent, None),
+        (ts.Optimizer.fire, ts.CellFilter.unit),
+        (ts.Optimizer.gradient_descent, ts.CellFilter.frechet),
+    ],
+)
+def test_optimizer_preserves_charge_spin(
+    optimizer_fn: ts.Optimizer,
+    cell_filter: ts.CellFilter | None,
+    ar_supercell_sim_state: SimState,
+    lj_model: ModelInterface,
+) -> None:
+    """Test that optimizers preserve charge and spin through initialization and steps."""
+    # Add perturbation to positions for meaningful optimization
+    ar_supercell_sim_state.positions = (
+        ar_supercell_sim_state.positions
+        + torch.randn_like(ar_supercell_sim_state.positions) * 0.1
+    )
+
+    # Set non-zero charge and spin values
+    original_charge = torch.tensor(
+        [5.0], device=ar_supercell_sim_state.device, dtype=ar_supercell_sim_state.dtype
+    )
+    original_spin = torch.tensor(
+        [6.0], device=ar_supercell_sim_state.device, dtype=ar_supercell_sim_state.dtype
+    )
+    ar_supercell_sim_state.charge = original_charge.clone()
+    ar_supercell_sim_state.spin = original_spin.clone()
+
+    init_fn, step_fn = ts.OPTIM_REGISTRY[optimizer_fn]
+    opt_state = init_fn(
+        model=lj_model, state=ar_supercell_sim_state, cell_filter=cell_filter
+    )
+
+    # Verify after initialization
+    assert torch.allclose(opt_state.charge, original_charge)
+    assert torch.allclose(opt_state.spin, original_spin)
+
+    # Run several optimization steps and verify preservation
+    for _ in range(3):
+        if optimizer_fn == ts.Optimizer.fire:
+            opt_state = step_fn(state=opt_state, model=lj_model, dt_max=0.3)
+        else:
+            opt_state = step_fn(state=opt_state, model=lj_model, pos_lr=0.01, cell_lr=0.1)
+
+    assert torch.allclose(opt_state.charge, original_charge)
+    assert torch.allclose(opt_state.spin, original_spin)
diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py
index 5583cc0c6..e69a4955f 100644
--- a/torch_sim/optimizers/fire.py
+++ b/torch_sim/optimizers/fire.py
@@ -82,6 +82,8 @@ def fire_init(
         "system_idx": state.system_idx.clone(),
         "_constraints": state.constraints,
         "pbc": state.pbc,
+        "charge": state.charge.clone(),
+        "spin": state.spin.clone(),
         # Optimization state
         "forces": forces,
         "energy": energy,
diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py
index d6bf52f57..283fe7c80 100644
--- a/torch_sim/optimizers/gradient_descent.py
+++ b/torch_sim/optimizers/gradient_descent.py
@@ -62,6 +62,8 @@ def gradient_descent_init(
         "atomic_numbers": state.atomic_numbers,
         "system_idx": state.system_idx,
         "_constraints": state.constraints,
+        "charge": state.charge,
+        "spin": state.spin,
     }
     if cell_filter is not None:
         # Create cell optimization state