From 40836bc725571c9d6fc30d885d1b069e9e198fc4 Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Mon, 14 Apr 2025 10:20:30 -0700 Subject: [PATCH 01/13] adds comparative test for fire optimizer with ase --- .../models/test_torchsim_vs_ase_fire_mace.py | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 tests/models/test_torchsim_vs_ase_fire_mace.py diff --git a/tests/models/test_torchsim_vs_ase_fire_mace.py b/tests/models/test_torchsim_vs_ase_fire_mace.py new file mode 100644 index 000000000..8c0ebc849 --- /dev/null +++ b/tests/models/test_torchsim_vs_ase_fire_mace.py @@ -0,0 +1,195 @@ +import copy +import random +import typing + +import numpy as np +import pytest +import torch +from ase.filters import FrechetCellFilter +from ase.optimize import FIRE + +from torch_sim.io import state_to_atoms +from torch_sim.optimizers import frechet_cell_fire +from torch_sim.state import SimState # Removed unused concatenate_states + + +try: + from mace.calculators.foundations_models import mace_mp + + from torch_sim.models.mace import MaceModel + + if typing.TYPE_CHECKING: + from mace.calculators import ( + MACECalculator, # Import ASE calculator for type hints only + ) + + MACE_AVAILABLE = True +except ImportError: + MACE_AVAILABLE = False + # If MACE is not available, skip all tests in this module. + pytestmark = pytest.mark.skipif(not MACE_AVAILABLE, reason="MACE not installed") + + +# Seed everything +torch.manual_seed(123) +# Replace legacy np.random.seed with Generator API +rng = np.random.default_rng(123) +random.seed(123) + + +@pytest.fixture +def torchsim_mace_mp_model(device: torch.device) -> MaceModel: + """Provides a MACE MP model instance for the optimizer tests.""" + # Use float64 for potentially higher precision needed in optimization + dtype = torch.float64 + # Ensure default_dtype is passed correctly for the raw model + mace_model_raw = mace_mp( + model="small", return_raw_model=True, default_dtype=str(dtype).split(".")[-1] + ) + return MaceModel( + model=mace_model_raw, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, # Stress needed for cell optimization + ) + + +@pytest.fixture +def ase_mace_mp_calculator( + device: torch.device, +) -> "MACECalculator": # Use quotes if MACECalculator type hint causes issues + """Provides an ASE MACECalculator instance using mace_mp.""" + # Ensure dtype matches the one used in the torchsim fixture (float64) + dtype_str = str(torch.float64).split(".")[-1] + # Use the mace_mp function to get the ASE calculator directly + return mace_mp( + model="small", + device=str(device), # MACECalculator expects device as string + default_dtype=dtype_str, # Match dtype + dispersion=False, # Assuming no dispersion for this test + ) + + +def test_unit_cell_frechet_fire_vs_ase( + ar_supercell_sim_state: SimState, + # Use the MACE model fixtures defined in this file + torchsim_mace_mp_model: MaceModel, + ase_mace_mp_calculator: "MACECalculator", +) -> None: + """Compare Frechet Cell FIRE optimizer with ASE's FIRE + ExpCellFilter using MACE.""" + # pytest.importorskip("ase") # ASE import handled by module-level skipif + + # Use float64 for consistency with the MACE model fixture + dtype = torch.float64 + device = torchsim_mace_mp_model.device + + # --- Setup Initial State with float64 --- + initial_state = copy.deepcopy(ar_supercell_sim_state) + initial_state = initial_state.to(dtype=dtype, device=device) + + generator = torch.Generator(device=initial_state.device) + generator.manual_seed(123) # Seed for 
reproducibility + initial_state.positions += ( + torch.randn( + initial_state.positions.shape, + device=initial_state.device, + dtype=initial_state.dtype, + generator=generator, + ) + * 0.1 + ) + # Ensure grads are enabled for both positions and cell + initial_state.positions = initial_state.positions.detach().requires_grad_( + requires_grad=True + ) + initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) + + n_steps = 20 + dt_max = 0.3 + dt_start = 0.1 + + # --- Run Custom Frechet Cell FIRE with MACE model --- + custom_init_fn, custom_update_fn = frechet_cell_fire( + model=torchsim_mace_mp_model, # Use torch-sim MACE wrapper + dt_max=dt_max, + dt_start=dt_start, + ) + custom_state = custom_init_fn(initial_state) + initial_custom_energy = custom_state.energy.item() # Initial energy check removed + + # --- Setup ASE System with native MACE calculator --- + ase_atoms = state_to_atoms(initial_state)[0] + # Use the native ASE MACE calculator fixture, not the removed TorchCalculator + ase_atoms.calc = ase_mace_mp_calculator + initial_ase_energy = ase_atoms.get_potential_energy() # Initial energy check removed + + # Initial Energy Check --- + assert abs(initial_custom_energy - initial_ase_energy) < 1e-7, ( + "Initial energies differ significantly" + ) + + # --- Continue Custom Optimization --- + for _ in range(n_steps): + custom_state = custom_update_fn(custom_state) + + # --- Run ASE FIRE with ExpCellFilter --- + filtered_atoms = FrechetCellFilter(ase_atoms) + ase_opt = FIRE( + filtered_atoms, + trajectory=None, + logfile=None, + dt=dt_start, + dtmax=dt_max, + ) + + try: + ase_opt.run(fmax=1e-4, steps=n_steps) + except (ValueError, RuntimeError) as e: + # Catch specific exceptions instead of blind Exception + print(f"ASE optimization failed: {e}") + pytest.fail("ASE optimization step failed.") + + # --- Compare Results (between custom_state and ase_sim_state) --- + final_custom_energy = custom_state.energy.item() + final_custom_forces_max = torch.norm(custom_state.forces, dim=-1).max().item() + + final_custom_pos = custom_state.positions.detach() + final_custom_cell = custom_state.cell.squeeze(0).detach() + + final_ase_energy = ase_atoms.get_potential_energy() + final_ase_forces = torch.tensor(ase_atoms.get_forces(), device=device, dtype=dtype) + + if final_ase_forces is not None: + final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item() + else: + final_ase_forces_max = float("nan") + + final_ase_pos = torch.tensor(ase_atoms.get_positions(), device=device, dtype=dtype) + final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype) + + # Compare energies (use looser tolerance for ML potential comparison) + assert abs(final_custom_energy - final_ase_energy) < 5e-2, ( + f"Final energies differ significantly after {n_steps} steps: " + f"Custom={final_custom_energy:.6f}, ASE_State={final_ase_energy:.6f}" + ) + + # Compare forces (report) + print( + f"Max Force ({n_steps} steps): Custom={final_custom_forces_max:.4f}, " + f"ASE_State={final_ase_forces_max:.4f}" + ) + + # Compare positions (looser tolerance) + pos_diff = torch.norm(final_custom_pos - final_ase_pos, dim=-1).mean().item() + assert pos_diff < 1.0, ( + f"Final positions differ significantly (avg displacement: {pos_diff:.4f})" + ) + + # Compare cell matrices (looser tolerance) + cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() + assert cell_diff < 1.0, ( + f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})" + f"\nCustom 
Cell:\n{final_custom_cell}" + f"\nASE_State Cell:\n{final_ase_cell}" + ) From 2ba8ef6ff26ab7b7e8cff8f929d8a18cd4185fdc Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Mon, 14 Apr 2025 10:25:57 -0700 Subject: [PATCH 02/13] rm a few comments --- .../models/test_torchsim_vs_ase_fire_mace.py | 26 ++++++------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/models/test_torchsim_vs_ase_fire_mace.py b/tests/models/test_torchsim_vs_ase_fire_mace.py index 8c0ebc849..250b58313 100644 --- a/tests/models/test_torchsim_vs_ase_fire_mace.py +++ b/tests/models/test_torchsim_vs_ase_fire_mace.py @@ -1,6 +1,5 @@ import copy import random -import typing import numpy as np import pytest @@ -10,29 +9,23 @@ from torch_sim.io import state_to_atoms from torch_sim.optimizers import frechet_cell_fire -from torch_sim.state import SimState # Removed unused concatenate_states +from torch_sim.state import SimState try: + from mace.calculators import MACECalculator from mace.calculators.foundations_models import mace_mp from torch_sim.models.mace import MaceModel - if typing.TYPE_CHECKING: - from mace.calculators import ( - MACECalculator, # Import ASE calculator for type hints only - ) - MACE_AVAILABLE = True except ImportError: MACE_AVAILABLE = False - # If MACE is not available, skip all tests in this module. pytestmark = pytest.mark.skipif(not MACE_AVAILABLE, reason="MACE not installed") # Seed everything torch.manual_seed(123) -# Replace legacy np.random.seed with Generator API rng = np.random.default_rng(123) random.seed(123) @@ -42,7 +35,6 @@ def torchsim_mace_mp_model(device: torch.device) -> MaceModel: """Provides a MACE MP model instance for the optimizer tests.""" # Use float64 for potentially higher precision needed in optimization dtype = torch.float64 - # Ensure default_dtype is passed correctly for the raw model mace_model_raw = mace_mp( model="small", return_raw_model=True, default_dtype=str(dtype).split(".")[-1] ) @@ -51,34 +43,32 @@ def torchsim_mace_mp_model(device: torch.device) -> MaceModel: device=device, dtype=dtype, compute_forces=True, - compute_stress=True, # Stress needed for cell optimization + compute_stress=True, ) @pytest.fixture def ase_mace_mp_calculator( device: torch.device, -) -> "MACECalculator": # Use quotes if MACECalculator type hint causes issues +) -> MACECalculator: """Provides an ASE MACECalculator instance using mace_mp.""" # Ensure dtype matches the one used in the torchsim fixture (float64) dtype_str = str(torch.float64).split(".")[-1] # Use the mace_mp function to get the ASE calculator directly return mace_mp( model="small", - device=str(device), # MACECalculator expects device as string - default_dtype=dtype_str, # Match dtype - dispersion=False, # Assuming no dispersion for this test + device=str(device), + default_dtype=dtype_str, + dispersion=False, ) def test_unit_cell_frechet_fire_vs_ase( ar_supercell_sim_state: SimState, - # Use the MACE model fixtures defined in this file torchsim_mace_mp_model: MaceModel, - ase_mace_mp_calculator: "MACECalculator", + ase_mace_mp_calculator: MACECalculator, ) -> None: """Compare Frechet Cell FIRE optimizer with ASE's FIRE + ExpCellFilter using MACE.""" - # pytest.importorskip("ase") # ASE import handled by module-level skipif # Use float64 for consistency with the MACE model fixture dtype = torch.float64 From 6e14e2bd915f55fe093252b1f09dbaf7f21a93c8 Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Mon, 14 Apr 2025 10:43:24 -0700 Subject: [PATCH 03/13] updates CI for the test --- 
.github/workflows/test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9a21e8890..cc43ed8f1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,6 +45,7 @@ jobs: --ignore=tests/models/test_orb.py \ --ignore=tests/models/test_sevennet.py \ --ignore=tests/models/test_metatensor.py + --ignore=tests/models/test_torchsim_vs_ase_fire_mace.py \ - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 @@ -64,6 +65,7 @@ jobs: - { name: fairchem, test_path: "tests/models/test_fairchem.py" } - { name: mace, test_path: "tests/models/test_mace.py" } - { name: mace, test_path: "tests/test_elastic.py" } + - { name: mace, test_path: "tests/models/test_torchsim_vs_ase_fire_mace.py" } - { name: mattersim, test_path: "tests/models/test_mattersim.py" } - { name: metatensor, test_path: "tests/models/test_metatensor.py" } - { name: orb, test_path: "tests/models/test_orb.py" } From f35462e6c0b2a0c8fcfe02260b381addd5df8319 Mon Sep 17 00:00:00 2001 From: Orion Cohen <27712051+orionarcher@users.noreply.github.com> Date: Mon, 14 Apr 2025 15:03:30 -0400 Subject: [PATCH 04/13] update changelog for v0.2.0 (#147) * update changelog for v0.2.0 * minor modification for PR template * formatting fixes * formatting and typos * remove contributors bc they aren't linked --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 606144d39..69828961d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,7 +81,6 @@ * New correlation function module, https://github.com/Radical-AI/torch-sim/pull/115 @stefanbringuier ### Documentation 📖 - * Imoved model documentation, https://github.com/Radical-AI/torch-sim/pull/121 @orionarcher * Plot of TorchSim module graph in docs, https://github.com/Radical-AI/torch-sim/pull/132 @janosh From 4a8ab14e46f6490486a39bd39e761b37697da883 Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Wed, 16 Apr 2025 08:20:55 -0700 Subject: [PATCH 05/13] update test with a harder system --- .../models/test_torchsim_vs_ase_fire_mace.py | 94 ++++++------------- 1 file changed, 30 insertions(+), 64 deletions(-) diff --git a/tests/models/test_torchsim_vs_ase_fire_mace.py b/tests/models/test_torchsim_vs_ase_fire_mace.py index 250b58313..e77239586 100644 --- a/tests/models/test_torchsim_vs_ase_fire_mace.py +++ b/tests/models/test_torchsim_vs_ase_fire_mace.py @@ -7,21 +7,17 @@ from ase.filters import FrechetCellFilter from ase.optimize import FIRE +import torch_sim as ts from torch_sim.io import state_to_atoms from torch_sim.optimizers import frechet_cell_fire from torch_sim.state import SimState -try: - from mace.calculators import MACECalculator - from mace.calculators.foundations_models import mace_mp +pytest.importorskip("mace") +from mace.calculators import MACECalculator +from mace.calculators.foundations_models import mace_mp - from torch_sim.models.mace import MaceModel - - MACE_AVAILABLE = True -except ImportError: - MACE_AVAILABLE = False - pytestmark = pytest.mark.skipif(not MACE_AVAILABLE, reason="MACE not installed") +from torch_sim.models.mace import MaceModel # Seed everything @@ -31,16 +27,19 @@ @pytest.fixture -def torchsim_mace_mp_model(device: torch.device) -> MaceModel: +def torchsim_mace_mp_model() -> MaceModel: """Provides a MACE MP model instance for the optimizer tests.""" # Use float64 for potentially higher precision needed in optimization dtype = torch.float64 + mace_checkpoint_url = 
"https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" mace_model_raw = mace_mp( - model="small", return_raw_model=True, default_dtype=str(dtype).split(".")[-1] + model=mace_checkpoint_url, + return_raw_model=True, + default_dtype=str(dtype).split(".")[-1], ) return MaceModel( model=mace_model_raw, - device=device, + device=torch.device("cpu"), dtype=dtype, compute_forces=True, compute_stress=True, @@ -48,47 +47,36 @@ def torchsim_mace_mp_model(device: torch.device) -> MaceModel: @pytest.fixture -def ase_mace_mp_calculator( - device: torch.device, -) -> MACECalculator: +def ase_mace_mp_calculator() -> MACECalculator: """Provides an ASE MACECalculator instance using mace_mp.""" # Ensure dtype matches the one used in the torchsim fixture (float64) + mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" dtype_str = str(torch.float64).split(".")[-1] # Use the mace_mp function to get the ASE calculator directly return mace_mp( - model="small", - device=str(device), + model=mace_checkpoint_url, + device=torch.device("cpu"), default_dtype=dtype_str, dispersion=False, ) def test_unit_cell_frechet_fire_vs_ase( - ar_supercell_sim_state: SimState, + rattled_sio2_sim_state: SimState, torchsim_mace_mp_model: MaceModel, ase_mace_mp_calculator: MACECalculator, ) -> None: - """Compare Frechet Cell FIRE optimizer with ASE's FIRE + ExpCellFilter using MACE.""" + """Compare Frechet Cell FIRE optimizer with + ASE's FIRE + FrechetCellFilter using MACE.""" # Use float64 for consistency with the MACE model fixture dtype = torch.float64 device = torchsim_mace_mp_model.device # --- Setup Initial State with float64 --- - initial_state = copy.deepcopy(ar_supercell_sim_state) + initial_state = copy.deepcopy(rattled_sio2_sim_state) initial_state = initial_state.to(dtype=dtype, device=device) - generator = torch.Generator(device=initial_state.device) - generator.manual_seed(123) # Seed for reproducibility - initial_state.positions += ( - torch.randn( - initial_state.positions.shape, - device=initial_state.device, - dtype=initial_state.dtype, - generator=generator, - ) - * 0.1 - ) # Ensure grads are enabled for both positions and cell initial_state.positions = initial_state.positions.detach().requires_grad_( requires_grad=True @@ -96,49 +84,27 @@ def test_unit_cell_frechet_fire_vs_ase( initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) n_steps = 20 - dt_max = 0.3 - dt_start = 0.1 + force_tol = 0.02 # --- Run Custom Frechet Cell FIRE with MACE model --- - custom_init_fn, custom_update_fn = frechet_cell_fire( - model=torchsim_mace_mp_model, # Use torch-sim MACE wrapper - dt_max=dt_max, - dt_start=dt_start, - ) - custom_state = custom_init_fn(initial_state) - initial_custom_energy = custom_state.energy.item() # Initial energy check removed + custom_state = ts.optimize( + system=initial_state, + model=torchsim_mace_mp_model, + optimizer=frechet_cell_fire, + max_steps=n_steps, + convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol), + ) # --- Setup ASE System with native MACE calculator --- ase_atoms = state_to_atoms(initial_state)[0] # Use the native ASE MACE calculator fixture, not the removed TorchCalculator ase_atoms.calc = ase_mace_mp_calculator - initial_ase_energy = ase_atoms.get_potential_energy() # Initial energy check removed - # Initial Energy Check --- - assert abs(initial_custom_energy - initial_ase_energy) < 1e-7, ( - "Initial energies differ significantly" - ) - - # --- 
Continue Custom Optimization --- - for _ in range(n_steps): - custom_state = custom_update_fn(custom_state) - - # --- Run ASE FIRE with ExpCellFilter --- + # --- Run ASE FIRE with FrechetCellFilter --- filtered_atoms = FrechetCellFilter(ase_atoms) - ase_opt = FIRE( - filtered_atoms, - trajectory=None, - logfile=None, - dt=dt_start, - dtmax=dt_max, - ) + ase_opt = FIRE(filtered_atoms) - try: - ase_opt.run(fmax=1e-4, steps=n_steps) - except (ValueError, RuntimeError) as e: - # Catch specific exceptions instead of blind Exception - print(f"ASE optimization failed: {e}") - pytest.fail("ASE optimization step failed.") + ase_opt.run(fmax=force_tol, steps=n_steps) # --- Compare Results (between custom_state and ase_sim_state) --- final_custom_energy = custom_state.energy.item() From e17db827de8ae191a3f4ee8f01d1ccf450a50627 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 17 Apr 2025 15:23:44 -0400 Subject: [PATCH 06/13] fix torch.device not iterable error in test_torchsim_vs_ase_fire_mace.py --- tests/models/test_torchsim_vs_ase_fire_mace.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/models/test_torchsim_vs_ase_fire_mace.py b/tests/models/test_torchsim_vs_ase_fire_mace.py index e77239586..6d08a89f4 100644 --- a/tests/models/test_torchsim_vs_ase_fire_mace.py +++ b/tests/models/test_torchsim_vs_ase_fire_mace.py @@ -39,7 +39,7 @@ def torchsim_mace_mp_model() -> MaceModel: ) return MaceModel( model=mace_model_raw, - device=torch.device("cpu"), + device="cpu", dtype=dtype, compute_forces=True, compute_stress=True, @@ -55,7 +55,7 @@ def ase_mace_mp_calculator() -> MACECalculator: # Use the mace_mp function to get the ASE calculator directly return mace_mp( model=mace_checkpoint_url, - device=torch.device("cpu"), + device="cpu", default_dtype=dtype_str, dispersion=False, ) @@ -74,8 +74,7 @@ def test_unit_cell_frechet_fire_vs_ase( device = torchsim_mace_mp_model.device # --- Setup Initial State with float64 --- - initial_state = copy.deepcopy(rattled_sio2_sim_state) - initial_state = initial_state.to(dtype=dtype, device=device) + initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device) # Ensure grads are enabled for both positions and cell initial_state.positions = initial_state.positions.detach().requires_grad_( From c04f41f3db21135bf582950926abb94f422df859 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 2 May 2025 10:33:51 +0100 Subject: [PATCH 07/13] fix: should compare the row_vector cell, clean: fix changelog typo --- CHANGELOG.md | 3 ++- tests/models/test_torchsim_vs_ase_fire_mace.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69828961d..cec422cbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,7 +81,8 @@ * New correlation function module, https://github.com/Radical-AI/torch-sim/pull/115 @stefanbringuier ### Documentation 📖 -* Imoved model documentation, https://github.com/Radical-AI/torch-sim/pull/121 @orionarcher + +* Improved model documentation, https://github.com/Radical-AI/torch-sim/pull/121 @orionarcher * Plot of TorchSim module graph in docs, https://github.com/Radical-AI/torch-sim/pull/132 @janosh ### House-Keeping 🧹 diff --git a/tests/models/test_torchsim_vs_ase_fire_mace.py b/tests/models/test_torchsim_vs_ase_fire_mace.py index 6d08a89f4..6da47acd2 100644 --- a/tests/models/test_torchsim_vs_ase_fire_mace.py +++ b/tests/models/test_torchsim_vs_ase_fire_mace.py @@ -110,7 +110,7 @@ def test_unit_cell_frechet_fire_vs_ase( final_custom_forces_max = 
torch.norm(custom_state.forces, dim=-1).max().item() final_custom_pos = custom_state.positions.detach() - final_custom_cell = custom_state.cell.squeeze(0).detach() + final_custom_cell = custom_state.row_vector_cell.squeeze(0).detach() final_ase_energy = ase_atoms.get_potential_energy() final_ase_forces = torch.tensor(ase_atoms.get_forces(), device=device, dtype=dtype)
From 1edca85386f65a8b9db73d349297f4fff7e713a2 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 2 May 2025 11:02:41 +0100 Subject: [PATCH 08/13] clean: delete .coverage and newline for pytest command --- .coverage | Bin 53248 -> 0 bytes .github/workflows/test.yml | 9 ++++++--- 2 files changed, 6 insertions(+), 3 deletions(-) delete mode 100644 .coverage
diff --git a/.coverage b/.coverage deleted file mode 100644 index 49e4419cc519d524d37c828547366b443bd89c47..0000000000000000000000000000000000000000 Binary files a/.coverage and /dev/null differ
z>z+D_)1{6*(fjDL8Z%q80}_dyX&us6PIuvmIM@Z<4>OJ5<>`CQH?ay7TV2MxsQYP| zFP`#a&0w$@4N9dt!LqSV&$0$~i6JVeZsw421BbdFox{RtujiJVB`(!e5;Y{%U1ulz z7%M%w@2yKOd}9AMN7JbUW+aZNRSyx^jm>A$=^+CecH4gA{)Z*jh@3;kU@TMxs!c5N z3-p#<)o<}~f)~-)+0RMOt&0Dmk}rChWasuZ&Q+yGbAq&#hvCTXlb%~G&)*_=SRPrZ zGxnPsBWk{?HxA@>#-AF;#??gy#?3&8b=G=fI5S=qlRpTfF(ZmMgwOVRfmlwT_NMKC zNP5=qFj!bH2AEOhUzHhq>&>FT3w{Bd0FVoc#6VaT$OXR^*6^xeMM$XgQ#uaSTlA0g zF%u0%2Mdu20n`OCthxbL-ufgWxQt7}BonV6V|7?g;4TulC@aCFfNPP~@RC6lIq*gD{WN`9(5WS4*b0{|<)K<~Yh~)tmqGdr$l?+MM^?_-L&Aw0lKo5TZ zKf=?DKhK}$KjcsH@AAj^*Won4ukg?Df8w9wALIA(zvl1ff68y;C;2h%@Of_Wx4~I~ zhxs9X72m_h_*T9VRFESAM1Tko0U|&IhyW2F0z`la5CJ0aawHIeExa@`A_CJC!SJvM zhK58iI4FXF0THCrBIxfIL0_K;dV58XN{L{@1`+i1h+zGC5p;KpVBI>Y5g-CY;FU`NzyGKG|0}nDsbxff2oM1xKm>>Y5g-CYfCvx) zB0vO?06zb}iQkBS|NjjC3I75A9)E&=8=e9927j1;jsG+MBEKJ=1Gta>1OHq8VSX3C z6P^WlFaKkHE5C&=@d`W-aD>}@iocz|3Dl4y0z`la5CI}U1c(3;AOb{y2oM1x@Vh}k z*#xKRD;v=pLC-{Q7`-9%2GJWpFO6P5dVT2iqL)H%1A0B^tw*mLy>;kyq1TBXM=yz9 g0zCu0IC?Skbo8R=Y3N1J>p)LMPeD&c56}PqCFAvkPyhe` diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cc43ed8f1..4b853d8e6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,11 +40,14 @@ jobs: - name: Run Tests run: | pytest --cov=torch_sim --cov-report=xml \ - --ignore=tests/models/test_mace.py \ + --ignore=tests/test_elastic.py \ --ignore=tests/models/test_fairchem.py \ + --ignore=tests/models/test_graphpes.py \ + --ignore=tests/models/test_mace.py \ --ignore=tests/models/test_orb.py \ --ignore=tests/models/test_sevennet.py \ - --ignore=tests/models/test_metatensor.py + --ignore=tests/models/test_mattersim.py \ + --ignore=tests/models/test_metatensor.py \ --ignore=tests/models/test_torchsim_vs_ase_fire_mace.py \ - name: Upload coverage to Codecov @@ -63,6 +66,7 @@ jobs: - { python: "3.12", resolution: lowest-direct } model: - { name: fairchem, test_path: "tests/models/test_fairchem.py" } + - { name: graphpes, test_path: "tests/models/test_graphpes.py" } - { name: mace, test_path: "tests/models/test_mace.py" } - { name: mace, test_path: "tests/test_elastic.py" } - { name: mace, test_path: "tests/models/test_torchsim_vs_ase_fire_mace.py" } @@ -70,7 +74,6 @@ jobs: - { name: metatensor, test_path: "tests/models/test_metatensor.py" } - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } - - { name: graphpes, test_path: "tests/models/test_graphpes.py" } runs-on: ${{ matrix.os }} steps: From b0c308826ab11e951fe4c92956b2adff73fe5bd9 Mon Sep 17 00:00:00 2001 From: Myles Stapelberg <35986272+mstapelberg@users.noreply.github.com> Date: Wed, 14 May 2025 11:50:33 -0400 Subject: [PATCH 09/13] Introduce ASE-style `FIRE` optimizer (departing from velocity Verlet in orig FIRE paper) and improve coverage in `test_optimizers.py` (#174) * feat(fire-optimizer-changes) Update fire_step in optimizers.py based feature/neb-workflow * reset optimizers.py to main version prior to adding updated changes * (feat:fire-optimizer-changes) - Added ase_fire_step and renamed fire_step to vv_fire_step. Allowed for selection of md_flavor * (feat:fire-optimizer-changes) - lint check on optimizers.py with ruff * (feat:fire-optimizer-changes) - added test cases and example script in examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py * (feat:fire-optimizer-changes) - updated FireState, UnitCellFireState, and FrechetCellFireState to have md_flavor to select vv or ase. ASE currently coverges in 1/3 as long. 
test cases for all three FIRE schemes added to test_optimizers.py with both md_flavors * ruff auto format * minor refactor of 7.6_Compare_ASE_to_VV_FIRE.py * refactor optimizers.py: define MdFlavor type alias for SSoT on MD flavors * new optimizer tests: FIRE and UnitCellFIRE initialization with dictionary states, md_flavor validation, non-positive volume warnings brings optimizers.py test coverage up to 96% * cleanup test_optimizers.py: parameterize tests for FIRE and UnitCellFIRE initialization and batch consistency checks maintains same 96% coverage * refactor optimizers.py: consolidate vv_fire_step logic into a single _vv_fire_step function modified by functools.partial for different unit cell optimizations (unit/frechet/bare fire=no cell relax) - more concise and maintainable code * same as prev commit but for _ase_fire_step instead of _vv_fire_step * (feat:fire-optimizer-changes) - added references to ASE implementation of FIRE and a link to the original FIRE paper. * (feat:fire-optimizer-changes) switched md_flavor type from str to MdFlavor and set default to ase_fire_step * pytest.mark.xfail frechet_cell_fire with ase_fire flavor, reason: shows asymmetry in batched mode, batch 0 stalls * rename maxstep to max_step for consistent snake_case fix RuntimeError: a leaf Variable that requires grad is being used in an in-place operation: 7. Position / cell update state.positions += dr_atom * unskip frechet_cell_fire in test_optimizer_batch_consistency, can no longer repro error locally * code cleanup * bumpy set-up action to v6, more descriptive CI test names * pin to fairchem_core-1.10.0 in CI * explain differences between vv_fire and ase_fire and link references in fire|unit_cell_fire|frechet_cell_fire doc strings --------- Co-authored-by: Janosh Riebesell --- .github/workflows/docs.yml | 2 +- .github/workflows/test.yml | 20 +- .gitignore | 2 +- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 263 ++++ pyproject.toml | 7 +- tests/test_optimizers.py | 1148 ++++++++--------- torch_sim/optimizers.py | 973 +++++++------- torch_sim/unbatched/unbatched_optimizers.py | 27 +- 8 files changed, 1373 insertions(+), 1069 deletions(-) create mode 100644 examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 8ddc37e7f..af5da9aec 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,7 +27,7 @@ jobs: python-version: "3.11" - name: Set up uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v6 - name: Install dependencies run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4b853d8e6..58d0eee19 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,12 +32,12 @@ jobs: python-version: ${{ matrix.version.python }} - name: Set up uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v6 - name: Install torch_sim run: uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system - - name: Run Tests + - name: Run core tests run: | pytest --cov=torch_sim --cov-report=xml \ --ignore=tests/test_elastic.py \ @@ -69,7 +69,10 @@ jobs: - { name: graphpes, test_path: "tests/models/test_graphpes.py" } - { name: mace, test_path: "tests/models/test_mace.py" } - { name: mace, test_path: "tests/test_elastic.py" } - - { name: mace, test_path: "tests/models/test_torchsim_vs_ase_fire_mace.py" } + - { + name: mace, + test_path: "tests/models/test_torchsim_vs_ase_fire_mace.py", + } - { name: mattersim, test_path: "tests/models/test_mattersim.py" } 
- { name: metatensor, test_path: "tests/models/test_metatensor.py" } - { name: orb, test_path: "tests/models/test_orb.py" } @@ -84,8 +87,9 @@ jobs: if: ${{ matrix.model.name == 'fairchem' }} uses: actions/checkout@v4 with: - repository: "FAIR-Chem/fairchem" - path: "fairchem-repo" + repository: FAIR-Chem/fairchem + path: fairchem-repo + ref: fairchem_core-1.10.0 - name: Set up Python uses: actions/setup-python@v5 @@ -93,7 +97,7 @@ jobs: python-version: ${{ matrix.version.python }} - name: Set up uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v6 - name: Install fairchem repository and dependencies if: ${{ matrix.model.name == 'fairchem' }} @@ -119,7 +123,7 @@ jobs: if: ${{ matrix.model.name != 'fairchem' }} run: uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system - - name: Run Tests with Coverage + - name: Run ${{ matrix.model.test_path }} tests run: | pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }} @@ -161,7 +165,7 @@ jobs: python-version: 3.11 - name: Set up uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v6 - name: Run example run: uv run --with . ${{ matrix.example }} diff --git a/.gitignore b/.gitignore index 29646ebf1..9c028c814 100644 --- a/.gitignore +++ b/.gitignore @@ -30,7 +30,7 @@ docs/reference/torch_sim.* # coverage coverage.xml -.coverage +.coverage* # env uv.lock diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py new file mode 100644 index 000000000..a6aa07e65 --- /dev/null +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -0,0 +1,263 @@ +"""Structural optimization with MACE using FIRE optimizer. +Comparing the ASE and VV FIRE optimizers. +""" + +# /// script +# dependencies = [ +# "mace-torch>=0.3.12", +# ] +# /// + +import os +import time + +import numpy as np +import torch +from ase.build import bulk +from mace.calculators.foundations_models import mace_mp + +import torch_sim as ts +from torch_sim.models.mace import MaceModel +from torch_sim.optimizers import fire +from torch_sim.state import SimState + + +# Set device, data type and unit conversion +device = "cuda" if torch.cuda.is_available() else "cpu" +dtype = torch.float32 +unit_conv = ts.units.UnitConversion + +# Option 1: Load the raw model from the downloaded model +mace_checkpoint_url = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" +loaded_model = mace_mp( + model=mace_checkpoint_url, + return_raw_model=True, + default_dtype=dtype, + device=device, +) + +# Number of steps to run +N_steps = 10 if os.getenv("CI") else 500 + +# Set random seed for reproducibility +rng = np.random.default_rng(seed=0) + +# Create diamond cubic Silicon +si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat((4, 4, 4)) +si_dc.positions += 0.3 * rng.standard_normal(si_dc.positions.shape) + +# Create FCC Copper +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((5, 5, 5)) +cu_dc.positions += 0.3 * rng.standard_normal(cu_dc.positions.shape) + +# Create BCC Iron +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((5, 5, 5)) +fe_dc.positions += 0.3 * rng.standard_normal(fe_dc.positions.shape) + +si_dc_vac = si_dc.copy() +si_dc_vac.positions += 0.3 * rng.standard_normal(si_dc_vac.positions.shape) +# select 2 numbers in range 0 to len(si_dc_vac) +indices = rng.choice(len(si_dc_vac), size=2, replace=False) +for idx in indices: + si_dc_vac.pop(idx) + + +cu_dc_vac = cu_dc.copy() +cu_dc_vac.positions += 0.3 * 
rng.standard_normal(cu_dc_vac.positions.shape) +# remove 2 atoms from cu_dc_vac at random +indices = rng.choice(len(cu_dc_vac), size=2, replace=False) +for idx in indices: + index = idx + 3 + if index < len(cu_dc_vac): + cu_dc_vac.pop(index) + else: + print(f"Index {index} is out of bounds for cu_dc_vac") + cu_dc_vac.pop(0) + +fe_dc_vac = fe_dc.copy() +fe_dc_vac.positions += 0.3 * rng.standard_normal(fe_dc_vac.positions.shape) +# remove 2 atoms from fe_dc_vac at random +indices = rng.choice(len(fe_dc_vac), size=2, replace=False) +for idx in indices: + index = idx + 2 + if index < len(fe_dc_vac): + fe_dc_vac.pop(index) + else: + print(f"Index {index} is out of bounds for fe_dc_vac") + fe_dc_vac.pop(0) + + +# Create a list of our atomic systems +atoms_list = [si_dc, cu_dc, fe_dc, si_dc_vac, cu_dc_vac] + +# Print structure information +print(f"Silicon atoms: {len(si_dc)}") +print(f"Copper atoms: {len(cu_dc)}") +print(f"Iron atoms: {len(fe_dc)}") +print(f"Total number of structures: {len(atoms_list)}") + +# Create batched model +model = MaceModel( + model=loaded_model, + device=device, + compute_forces=True, + compute_stress=True, + dtype=dtype, + enable_cueq=False, +) + +# Convert atoms to state +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +# Run initial inference +initial_energies = model(state)["energy"] + + +def run_optimization( + initial_state: SimState, md_flavor: str, force_tol: float = 0.05 +) -> tuple[torch.Tensor, SimState]: + """Runs FIRE optimization and returns convergence steps.""" + print(f"\n--- Running optimization with MD Flavor: {md_flavor} ---") + start_time = time.perf_counter() + + # Re-initialize state and optimizer for this run + init_fn, update_fn = fire( + model=model, + md_flavor=md_flavor, + ) + fire_state = init_fn(initial_state.clone()) # Use a clone to start fresh + + batcher = ts.InFlightAutoBatcher( + model=model, + memory_scales_with="n_atoms", + max_memory_scaler=1000, + max_iterations=1000, # Increased max iterations + return_indices=True, # Ensure indices are returned + ) + + batcher.load_states(fire_state) + + total_structures = fire_state.n_batches + # Initialize convergence steps tensor (-1 means not converged yet) + convergence_steps = torch.full( + (total_structures,), -1, dtype=torch.long, device=device + ) + convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + + converged_tensor_global = torch.zeros( + total_structures, dtype=torch.bool, device=device + ) + global_step = 0 + all_converged_states = [] # Initialize list to store completed states + convergence_tensor_for_batcher = None # Initialize convergence tensor for batcher + + # Keep track of the last valid state for final collection + last_active_state = fire_state + + while True: # Loop until batcher indicates completion + # Get the next batch, passing the convergence status + result = batcher.next_batch(last_active_state, convergence_tensor_for_batcher) + + fire_state, converged_states_from_batcher, current_indices_list = result + all_converged_states.extend( + converged_states_from_batcher + ) # Add newly completed states + + if fire_state is None: # No more active states + print("All structures converged or batcher reached max iterations.") + break + + last_active_state = fire_state # Store the current active state + + # Get the original indices of the current active batch as a tensor + current_indices = torch.tensor( + current_indices_list, dtype=torch.long, device=device + ) + + # Optimize the current batch + steps_this_round = 10 + for _ in 
range(steps_this_round): + fire_state = update_fn(fire_state) + global_step += steps_this_round # Increment global step count + + # Check convergence *within the active batch* + convergence_tensor_for_batcher = convergence_fn(fire_state, None) + + # Update global convergence status and steps + # Identify structures in this batch that just converged + newly_converged_mask_local = convergence_tensor_for_batcher & ( + convergence_steps[current_indices] == -1 + ) + converged_indices_global = current_indices[newly_converged_mask_local] + + if converged_indices_global.numel() > 0: + # Mark convergence step + convergence_steps[converged_indices_global] = global_step + converged_tensor_global[converged_indices_global] = True + converged_indices = converged_indices_global.tolist() + + total_converged = converged_tensor_global.sum().item() / total_structures + print(f"{global_step=}: {converged_indices=}, {total_converged=:.2%}") + + # Optional: Print progress + if global_step % 50 == 0: # Reduced frequency + total_converged = converged_tensor_global.sum().item() / total_structures + active_structures = fire_state.n_batches if fire_state else 0 + print(f"{global_step=}: {active_structures=}, {total_converged=:.2%}") + + # After the loop, collect any remaining states that were active in the last batch + # result[1] contains states completed *before* the last next_batch call. + # We need the states that were active *in* the last batch returned by next_batch + # If fire_state was the last active state, we might need to add it if batcher didn't + # mark it complete. However, restore_original_order should handle all collected states + # correctly. + + # Restore original order and concatenate + final_states_list = batcher.restore_original_order(all_converged_states) + final_state_concatenated = ts.concatenate_states(final_states_list) + + end_time = time.perf_counter() + print(f"Finished {md_flavor} in {end_time - start_time:.2f} seconds.") + # Return both convergence steps and the final state object + return convergence_steps, final_state_concatenated + + +# --- Main Script --- +force_tol = 0.05 + +# Run with ase_fire +ase_steps, ase_final_state = run_optimization( + state.clone(), "ase_fire", force_tol=force_tol +) +# Run with vv_fire +vv_steps, vv_final_state = run_optimization(state.clone(), "vv_fire", force_tol=force_tol) + +print("\n--- Comparison ---") +print(f"{force_tol=:.2f} eV/Å") + +# Calculate Mean Position Displacements +ase_final_states_list = ase_final_state.split() +vv_final_states_list = vv_final_state.split() +mean_displacements = [] +for idx in range(len(ase_final_states_list)): + ase_pos = ase_final_states_list[idx].positions + vv_pos = vv_final_states_list[idx].positions + displacement = torch.norm(ase_pos - vv_pos, dim=1) + mean_disp = torch.mean(displacement).item() + mean_displacements.append(mean_disp) + + +print(f"Initial energies: {[f'{e.item():.3f}' for e in initial_energies]} eV") +print(f"Final ASE energies: {[f'{e.item():.3f}' for e in ase_final_state.energy]} eV") +print(f"Final VV energies: {[f'{e.item():.3f}' for e in vv_final_state.energy]} eV") +print(f"Mean Disp (ASE-VV): {[f'{d:.4f}' for d in mean_displacements]} Å") +print(f"Convergence steps (ASE FIRE): {ase_steps.tolist()}") +print(f"Convergence steps (VV FIRE): {vv_steps.tolist()}") + +# Identify structures that didn't converge +ase_not_converged = torch.where(ase_steps == -1)[0].tolist() +vv_not_converged = torch.where(vv_steps == -1)[0].tolist() + +if ase_not_converged: + print(f"ASE FIRE did not converge for 
indices: {ase_not_converged}") +if vv_not_converged: + print(f"VV FIRE did not converge for indices: {vv_not_converged}") diff --git a/pyproject.toml b/pyproject.toml index dc2824b2a..40b02ec4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,10 +139,5 @@ check-filenames = true ignore-words-list = ["convertor"] [tool.pytest.ini_options] -addopts = [ - "--cov-report=term-missing", - "--cov=torch_sim", - "-p no:warnings", - "-v", -] +addopts = ["-p no:warnings"] testpaths = ["tests"] diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 5547fc781..f4dce0610 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1,8 +1,17 @@ import copy +from dataclasses import fields +from typing import get_args +import pytest import torch from torch_sim.optimizers import ( + FireState, + FrechetCellFIREState, + GDState, + MdFlavor, + UnitCellFireState, + UnitCellGDState, fire, frechet_cell_fire, gradient_descent, @@ -43,13 +52,13 @@ def test_gradient_descent_optimization( # Check that energy decreased assert energies[-1] < energies[0], ( - f"FIRE optimization should reduce energy " + f"Gradient Descent optimization should reduce energy " f"(initial: {energies[0]}, final: {energies[-1]})" ) # Check force convergence max_force = torch.max(torch.norm(state.forces, dim=1)) - assert max_force < 0.2, f"Forces should be small after optimization (got {max_force})" + assert max_force < 0.2, f"Forces should be small after optimization, got {max_force=}" assert not torch.allclose(state.positions, initial_state.positions) @@ -86,7 +95,7 @@ def test_unit_cell_gradient_descent_optimization( # Check that energy decreased assert energies[-1] < energies[0], ( - f"FIRE optimization should reduce energy " + f"Gradient Descent optimization should reduce energy " f"(initial: {energies[0]}, final: {energies[-1]})" ) @@ -94,138 +103,309 @@ def test_unit_cell_gradient_descent_optimization( max_force = torch.max(torch.norm(state.forces, dim=1)) pressure = torch.trace(state.stress.squeeze(0)) / 3.0 assert pressure < 0.01, ( - f"Pressure should be small after optimization (got {pressure})" + f"Pressure should be small after optimization, got {pressure=}" ) - assert max_force < 0.2, f"Forces should be small after optimization (got {max_force})" + assert max_force < 0.2, f"Forces should be small after optimization, got {max_force=}" assert not torch.allclose(state.positions, initial_state.positions) assert not torch.allclose(state.cell, initial_state.cell) +@pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) def test_fire_optimization( - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module + ar_supercell_sim_state: SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor ) -> None: """Test that the FIRE optimizer actually minimizes energy.""" # Add some random displacement to positions - perturbed_positions = ( - ar_supercell_sim_state.positions + # Create a fresh copy for each test run to avoid interference + + current_positions = ( + ar_supercell_sim_state.positions.clone() + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 ) - ar_supercell_sim_state.positions = perturbed_positions - initial_state = ar_supercell_sim_state + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + batch=ar_supercell_sim_state.batch.clone(), + ) + + initial_state_positions = 
current_sim_state.positions.clone() # Initialize FIRE optimizer init_fn, update_fn = fire( model=lj_model, dt_max=0.3, dt_start=0.1, + md_flavor=md_flavor, ) - state = init_fn(ar_supercell_sim_state) + state = init_fn(current_sim_state) # Run optimization for a few steps energies = [1000, state.energy.item()] - while abs(energies[-2] - energies[-1]) > 1e-6: + max_steps = 1000 # Add max step to prevent infinite loop + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: state = update_fn(state) energies.append(state.energy.item()) + steps_taken += 1 + + if steps_taken == max_steps: + print(f"FIRE optimization for {md_flavor=} did not converge in {max_steps} steps") energies = energies[1:] # Check that energy decreased assert energies[-1] < energies[0], ( - f"FIRE optimization should reduce energy " + f"FIRE optimization for {md_flavor=} should reduce energy " f"(initial: {energies[0]}, final: {energies[-1]})" ) # Check force convergence max_force = torch.max(torch.norm(state.forces, dim=1)) - assert max_force < 0.2, f"Forces should be small after optimization (got {max_force})" + # bumped up the tolerance to 0.3 to account for the fact that ase_fire is more lenient + # in beginning steps + assert max_force < 0.3, ( + f"{md_flavor=} forces should be small after optimization, got {max_force=}" + ) - assert not torch.allclose(state.positions, initial_state.positions) + assert not torch.allclose(state.positions, initial_state_positions), ( + f"{md_flavor=} positions should have changed after optimization." + ) -def test_unit_cell_fire_optimization( +@pytest.mark.parametrize( + ("optimizer_fn", "expected_state_type"), + [(fire, FireState), (gradient_descent, GDState)], +) +def test_simple_optimizer_init_with_dict( + optimizer_fn: callable, + expected_state_type: type, + ar_supercell_sim_state: SimState, + lj_model: torch.nn.Module, +) -> None: + """Test simple optimizer init_fn with a SimState dictionary.""" + state_dict = { + f.name: getattr(ar_supercell_sim_state, f.name) + for f in fields(ar_supercell_sim_state) + } + init_fn, _ = optimizer_fn(model=lj_model) + opt_state = init_fn(state_dict) + assert isinstance(opt_state, expected_state_type) + assert opt_state.energy is not None + assert opt_state.forces is not None + + +@pytest.mark.parametrize("optimizer_func", [fire, unit_cell_fire, frechet_cell_fire]) +def test_optimizer_invalid_md_flavor( + optimizer_func: callable, lj_model: torch.nn.Module +) -> None: + """Test optimizer with an invalid md_flavor raises ValueError.""" + with pytest.raises(ValueError, match="Unknown md_flavor"): + optimizer_func(model=lj_model, md_flavor="invalid_flavor") + + +def test_fire_ase_negative_power_branch( ar_supercell_sim_state: SimState, lj_model: torch.nn.Module ) -> None: - """Test that the FIRE optimizer actually minimizes energy.""" - # Add some random displacement to positions - perturbed_positions = ( - ar_supercell_sim_state.positions - + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + """Test that the ASE FIRE P<0 branch behaves as expected.""" + f_dec = 0.5 # Default from fire optimizer + alpha_start = 0.1 # Default from fire optimizer + dt_start_val = 0.1 + + init_fn, update_fn = fire( + model=lj_model, + md_flavor="ase_fire", + f_dec=f_dec, + alpha_start=alpha_start, + dt_start=dt_start_val, + dt_max=1.0, + max_step=10.0, # Large max_step to not interfere with velocity check + ) + # Initialize state (forces are computed here) + state = init_fn(ar_supercell_sim_state) + + # Save parameters from 
initial state + initial_dt_batch = state.dt.clone() # per-batch dt + + # Manipulate state to ensure P < 0 for the update_fn step + # Ensure forces are non-trivial + state.forces += torch.sign(state.forces + 1e-6) * 1e-2 + state.forces[torch.abs(state.forces) < 1e-3] = 1e-3 + # Set velocities directly opposite to current forces + state.velocities = -state.forces * 0.1 # v = -k * F + + # Store forces that will be used in the power calculation and v += dt*F step + forces_at_power_calc = state.forces.clone() + + # Deepcopy state as update_fn modifies it in-place + state_to_update = copy.deepcopy(state) + updated_state = update_fn(state_to_update) + + # Assertions for P < 0 branch being taken + # Check for a single-batch state (ar_supercell_sim_state is single batch) + expected_dt_val = initial_dt_batch[0] * f_dec + assert torch.allclose(updated_state.dt[0], expected_dt_val) + assert torch.allclose( + updated_state.alpha[0], + torch.tensor( + alpha_start, + dtype=updated_state.alpha.dtype, + device=updated_state.alpha.device, + ), ) + assert updated_state.n_pos[0] == 0 - ar_supercell_sim_state.positions = perturbed_positions - initial_state = ar_supercell_sim_state + # Assertions for velocity update in ASE P < 0 case: + # v_after_mixing_is_0, then v_final = dt_new * F_at_power_calc + expected_final_velocities = ( + expected_dt_val * forces_at_power_calc[updated_state.batch == 0] + ) + assert torch.allclose( + updated_state.velocities[updated_state.batch == 0], + expected_final_velocities, + atol=1e-6, + ) - # Initialize FIRE optimizer - init_fn, update_fn = unit_cell_fire( + +def test_fire_vv_negative_power_branch( + ar_supercell_sim_state: SimState, lj_model: torch.nn.Module +) -> None: + """Attempt to trigger and test the VV FIRE P<0 branch.""" + f_dec = 0.5 + alpha_start = 0.1 + # Use a very large dt_start to encourage overshooting and P<0 inside _vv_fire_step + dt_start_val = 2.0 + dt_max_val = 2.0 + + init_fn, update_fn = fire( model=lj_model, - dt_max=0.3, - dt_start=0.1, + md_flavor="vv_fire", + f_dec=f_dec, + alpha_start=alpha_start, + dt_start=dt_start_val, + dt_max=dt_max_val, + n_min=0, # Allow dt to change immediately ) - state = init_fn(ar_supercell_sim_state) - # Run optimization for a few steps - energies = [1000, state.energy.item()] - while abs(energies[-2] - energies[-1]) > 1e-6: - state = update_fn(state) - energies.append(state.energy.item()) + initial_dt_batch = state.dt.clone() + initial_alpha_batch = state.alpha.clone() # Already alpha_start + initial_n_pos_batch = state.n_pos.clone() # Already 0 - energies = energies[1:] + state_to_update = copy.deepcopy(state) + updated_state = update_fn(state_to_update) - # Check that energy decreased - assert energies[-1] < energies[0], ( - f"FIRE optimization should reduce energy " - f"(initial: {energies[0]}, final: {energies[-1]})" + # Check if the P<0 branch was likely hit (params changed accordingly for batch 0) + expected_dt_val = initial_dt_batch[0] * f_dec + expected_alpha_val = torch.tensor( + alpha_start, + dtype=initial_alpha_batch.dtype, + device=initial_alpha_batch.device, ) - # Check force convergence - max_force = torch.max(torch.norm(state.forces, dim=1)) - pressure = torch.trace(state.stress.squeeze(0)) / 3.0 - assert pressure < 0.01, ( - f"Pressure should be small after optimization (got {pressure})" + p_lt_0_branch_taken = ( + torch.allclose(updated_state.dt[0], expected_dt_val) + and torch.allclose(updated_state.alpha[0], expected_alpha_val) + and updated_state.n_pos[0] == 0 ) - assert max_force < 0.2, f"Forces 
should be small after optimization (got {max_force})" - assert not torch.allclose(state.positions, initial_state.positions) - assert not torch.allclose(state.cell, initial_state.cell) + if not p_lt_0_branch_taken: + pytest.skip( + f"VV FIRE P<0 condition not reliably hit for batch 0. " + f"dt: {initial_dt_batch[0].item():.4f} -> {updated_state.dt[0].item():.4f} " + f"(expected factor {f_dec}). " + f"alpha: {initial_alpha_batch[0].item():.4f} -> " + f"{updated_state.alpha[0].item():.4f} (expected {alpha_start}). " + f"n_pos: {initial_n_pos_batch[0].item()} -> {updated_state.n_pos[0].item()} " + "(expected 0)." + ) + # If P<0 branch was taken, velocities should be zeroed + assert torch.allclose( + updated_state.velocities[updated_state.batch == 0], + torch.zeros_like(updated_state.velocities[updated_state.batch == 0]), + atol=1e-7, + ) -def test_unit_cell_frechet_fire_optimization( - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module + +@pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) +def test_unit_cell_fire_optimization( + ar_supercell_sim_state: SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor ) -> None: - """Test that the FIRE optimizer actually minimizes energy.""" - # Add some random displacement to positions - perturbed_positions = ( - ar_supercell_sim_state.positions + """Test that the Unit Cell FIRE optimizer actually minimizes energy.""" + print(f"\n--- Starting test_unit_cell_fire_optimization for {md_flavor=} ---") + + # Add random displacement to positions and cell + current_positions = ( + ar_supercell_sim_state.positions.clone() + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 ) + current_cell = ( + ar_supercell_sim_state.cell.clone() + + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 + ) - ar_supercell_sim_state.positions = perturbed_positions - initial_state = ar_supercell_sim_state + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=current_cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + batch=ar_supercell_sim_state.batch.clone(), + ) + print(f"[{md_flavor}] Initial SimState created.") + + initial_state_positions = current_sim_state.positions.clone() + initial_state_cell = current_sim_state.cell.clone() # Initialize FIRE optimizer - init_fn, update_fn = frechet_cell_fire( + print(f"Initializing {md_flavor} optimizer...") + init_fn, update_fn = unit_cell_fire( model=lj_model, dt_max=0.3, dt_start=0.1, + md_flavor=md_flavor, ) + print(f"[{md_flavor}] Optimizer functions obtained.") - state = init_fn(ar_supercell_sim_state) + state = init_fn(current_sim_state) + energy = float(getattr(state, "energy", "nan")) + print(f"[{md_flavor}] Initial state created by init_fn. {energy=:.4f}") # Run optimization for a few steps - energies = [1000, state.energy.item()] - while abs(energies[-2] - energies[-1]) > 1e-6: + energies = [1000.0, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + print(f"[{md_flavor}] Entering optimization loop (max_steps: {max_steps})...") + + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: state = update_fn(state) energies.append(state.energy.item()) + steps_taken += 1 + + print(f"[{md_flavor}] Loop finished after {steps_taken} steps.") + + if steps_taken == max_steps and abs(energies[-2] - energies[-1]) > 1e-6: + print( + f"WARNING: Unit Cell FIRE {md_flavor=} optimization did not converge " + f"in {max_steps} steps. 
Final energy: {energies[-1]:.4f}" + ) + else: + print( + f"Unit Cell FIRE {md_flavor=} optimization converged in {steps_taken} " + f"steps. Final energy: {energies[-1]:.4f}" + ) energies = energies[1:] # Check that energy decreased assert energies[-1] < energies[0], ( - f"FIRE optimization should reduce energy " + f"Unit Cell FIRE {md_flavor=} optimization should reduce energy " f"(initial: {energies[0]}, final: {energies[-1]})" ) @@ -233,184 +413,341 @@ def test_unit_cell_frechet_fire_optimization( max_force = torch.max(torch.norm(state.forces, dim=1)) pressure = torch.trace(state.stress.squeeze(0)) / 3.0 assert pressure < 0.01, ( - f"Pressure should be small after optimization (got {pressure})" + f"Pressure should be small after optimization, got {pressure=}" + ) + assert max_force < 0.3, ( + f"{md_flavor=} forces should be small after optimization, got {max_force}" ) - assert max_force < 0.2, f"Forces should be small after optimization (got {max_force})" - assert not torch.allclose(state.positions, initial_state.positions) - assert not torch.allclose(state.cell, initial_state.cell) + assert not torch.allclose(state.positions, initial_state_positions), ( + f"{md_flavor=} positions should have changed after optimization." + ) + assert not torch.allclose(state.cell, initial_state_cell), ( + f"{md_flavor=} cell should have changed after optimization." + ) -def test_fire_multi_batch( - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module +@pytest.mark.parametrize( + ("optimizer_fn", "expected_state_type", "cell_factor_val"), + [ + (unit_cell_fire, UnitCellFireState, 100), + (unit_cell_gradient_descent, UnitCellGDState, 50.0), + (frechet_cell_fire, FrechetCellFIREState, 75.0), + ], +) +def test_cell_optimizer_init_with_dict_and_cell_factor( + optimizer_fn: callable, + expected_state_type: type, + cell_factor_val: float, + ar_supercell_sim_state: SimState, + lj_model: torch.nn.Module, ) -> None: - """Test FIRE optimization with multiple batches.""" - # Create a multi-batch system by duplicating ar_fcc_state + """Test cell optimizer init_fn with dict state and explicit cell_factor.""" + state_dict = { + f.name: getattr(ar_supercell_sim_state, f.name) + for f in fields(ar_supercell_sim_state) + } + init_fn, _ = optimizer_fn(model=lj_model, cell_factor=cell_factor_val) + opt_state = init_fn(state_dict) + + assert isinstance(opt_state, expected_state_type) + assert opt_state.energy is not None + assert opt_state.forces is not None + assert opt_state.stress is not None + expected_cf_tensor = torch.full( + (opt_state.n_batches, 1, 1), + float(cell_factor_val), # Ensure float for comparison if int is passed + device=lj_model.device, + dtype=lj_model.dtype, + ) + assert torch.allclose(opt_state.cell_factor, expected_cf_tensor) + + +@pytest.mark.parametrize( + ("optimizer_fn", "expected_state_type"), + [ + (unit_cell_fire, UnitCellFireState), + (frechet_cell_fire, FrechetCellFIREState), + ], +) +def test_cell_optimizer_init_cell_factor_none( + optimizer_fn: callable, + expected_state_type: type, + ar_supercell_sim_state: SimState, + lj_model: torch.nn.Module, +) -> None: + """Test cell optimizer init_fn with cell_factor=None.""" + init_fn, _ = optimizer_fn(model=lj_model, cell_factor=None) + # Ensure n_batches > 0 for cell_factor calculation from counts + assert ar_supercell_sim_state.n_batches > 0 + opt_state = init_fn(ar_supercell_sim_state) # Uses SimState directly + assert isinstance(opt_state, expected_state_type) + _, counts = torch.unique(ar_supercell_sim_state.batch, return_counts=True) 
+ expected_cf_tensor = counts.to(dtype=lj_model.dtype).view(-1, 1, 1) + assert torch.allclose(opt_state.cell_factor, expected_cf_tensor) + assert opt_state.energy is not None + assert opt_state.forces is not None + assert opt_state.stress is not None + + +@pytest.mark.filterwarnings("ignore:WARNING: Non-positive volume detected") +def test_unit_cell_fire_ase_non_positive_volume_warning( + ar_supercell_sim_state: SimState, + lj_model: torch.nn.Module, + capsys: pytest.CaptureFixture, +) -> None: + """Attempt to trigger non-positive volume warning in unit_cell_fire ASE.""" + # Use a state that might lead to cell inversion with aggressive steps + # Make a copy and slightly perturb the cell to make it prone to issues + perturbed_state = ar_supercell_sim_state.clone() + perturbed_state.cell += ( + torch.randn_like(perturbed_state.cell) * 0.5 + ) # Large perturbation + # Also ensure no PBC issues by slightly expanding cell if it got too small + if torch.linalg.det(perturbed_state.cell[0]) < 1.0: + perturbed_state.cell[0] *= 2.0 - generator = torch.Generator(device=ar_supercell_sim_state.device) + init_fn, update_fn = unit_cell_fire( + model=lj_model, + md_flavor="ase_fire", + dt_max=5.0, # Large dt + max_step=2.0, # Large max_step + dt_start=1.0, + f_dec=0.99, # Slow down dt decrease + alpha_start=0.99, # Aggressive alpha + ) + state = init_fn(perturbed_state) - ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) - ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) + # Run a few steps hoping to trigger the warning + for _ in range(5): + state = update_fn(state) + if "WARNING: Non-positive volume detected" in capsys.readouterr().err: + break # Warning captured - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - generator.manual_seed(43) - state.positions += ( - torch.randn( - state.positions.shape, - device=state.device, - generator=generator, - ) - * 0.1 - ) + assert state is not None # Ensure optimizer ran - multi_state = concatenate_states( - [ar_supercell_sim_state_1, ar_supercell_sim_state_2], - device=ar_supercell_sim_state.device, + +@pytest.mark.parametrize("md_flavor", get_args(MdFlavor)) +def test_frechet_cell_fire_optimization( + ar_supercell_sim_state: SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor +) -> None: + """Test that the Frechet Cell FIRE optimizer actually minimizes energy for different + md_flavors.""" + print(f"\n--- Starting test_frechet_cell_fire_optimization for {md_flavor=} ---") + + # Add random displacement to positions and cell + # Create a fresh copy for each test run to avoid interference + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + current_cell = ( + ar_supercell_sim_state.cell.clone() + + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 ) + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=current_cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + batch=ar_supercell_sim_state.batch.clone(), + ) + print(f"[{md_flavor}] Initial SimState created for Frechet test.") + + initial_state_positions = current_sim_state.positions.clone() + initial_state_cell = current_sim_state.cell.clone() + # Initialize FIRE optimizer - init_fn, update_fn = fire( + print(f"Initializing Frechet {md_flavor} optimizer...") + init_fn, update_fn = frechet_cell_fire( model=lj_model, dt_max=0.3, dt_start=0.1, + 
md_flavor=md_flavor, ) + print(f"[{md_flavor}] Frechet optimizer functions obtained.") - state = init_fn(multi_state) - initial_state = copy.deepcopy(state) + state = init_fn(current_sim_state) + energy = float(getattr(state, "energy", "nan")) + print(f"[{md_flavor}] Initial state created by Frechet init_fn. {energy=:.4f}") # Run optimization for a few steps - prev_energy = torch.ones(2, device=state.device, dtype=state.energy.dtype) * 1000 - current_energy = initial_state.energy - step = 0 - while not torch.allclose(current_energy, prev_energy, atol=1e-9): - prev_energy = current_energy + energies = [1000.0, state.energy.item()] # Ensure float for comparison + max_steps = 1000 + steps_taken = 0 + print(f"[{md_flavor}] Entering Frechet optimization loop (max_steps: {max_steps})...") + + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: state = update_fn(state) - current_energy = state.energy + energies.append(state.energy.item()) + steps_taken += 1 - step += 1 - if step > 500: - raise ValueError("Optimization did not converge") + print(f"[{md_flavor}] Frechet loop finished after {steps_taken} steps.") - # check that we actually optimized - assert step > 10 + if steps_taken == max_steps and abs(energies[-2] - energies[-1]) > 1e-6: + print( + f"WARNING: Frechet Cell FIRE {md_flavor=} optimization did not converge " + f"in {max_steps} steps. Final energy: {energies[-1]:.4f}" + ) + else: + print( + f"Frechet Cell FIRE {md_flavor=} optimization converged in {steps_taken} " + f"steps. Final energy: {energies[-1]:.4f}" + ) - # Check that energy decreased for both batches - assert torch.all(state.energy < initial_state.energy), ( - "FIRE optimization should reduce energy for all batches" + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"Frechet FIRE {md_flavor=} optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" ) - # transfer the energy and force checks to the batched optimizer + # Check force convergence max_force = torch.max(torch.norm(state.forces, dim=1)) - assert torch.all(max_force < 0.1), ( - f"Forces should be small after optimization (got {max_force})" - ) + # Assumes single batch for this state stress access + pressure = torch.trace(state.stress.squeeze(0)) / 3.0 - n_ar_atoms = ar_supercell_sim_state.n_atoms - assert not torch.allclose( - state.positions[:n_ar_atoms], multi_state.positions[:n_ar_atoms] + # Adjust tolerances if needed, Frechet might behave slightly differently + pressure_tol = 0.01 + force_tol = 0.2 + + assert torch.abs(pressure) < pressure_tol, ( + f"{md_flavor=} pressure should be below {pressure_tol=} after Frechet " + f"optimization, got {pressure.item()}" ) - assert not torch.allclose( - state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:] + assert max_force < force_tol, ( + f"{md_flavor=} forces should be below {force_tol=} after Frechet optimization, " + f"got {max_force}" ) - # we are evolving identical systems - assert current_energy[0] == current_energy[1] + assert not torch.allclose(state.positions, initial_state_positions, atol=1e-5), ( + f"{md_flavor=} positions should have changed after Frechet optimization." + ) + assert not torch.allclose(state.cell, initial_state_cell, atol=1e-5), ( + f"{md_flavor=} cell should have changed after Frechet optimization." 
+ ) -def test_fire_batch_consistency( - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module +@pytest.mark.parametrize("optimizer_func", [fire, unit_cell_fire, frechet_cell_fire]) +def test_optimizer_batch_consistency( + optimizer_func: callable, + ar_supercell_sim_state: SimState, + lj_model: torch.nn.Module, ) -> None: - """Test batched FIRE optimization is consistent with individual optimizations.""" + """Test batched optimizer is consistent with individual optimizations.""" generator = torch.Generator(device=ar_supercell_sim_state.device) - ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) - ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) + # Create two distinct initial states by cloning and perturbing + state1_orig = ar_supercell_sim_state.clone() - # Add same random perturbation to both states - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - generator.manual_seed(43) - state.positions += ( + # Apply identical perturbations to state1_orig + # for state_item in [state1_orig, state2_orig]: # Old loop structure + generator.manual_seed(43) # Reset seed for positions + state1_orig.positions += ( + torch.randn( + state1_orig.positions.shape, device=state1_orig.device, generator=generator + ) + * 0.1 + ) + if optimizer_func in (unit_cell_fire, frechet_cell_fire): + generator.manual_seed(44) # Reset seed for cell + state1_orig.cell += ( torch.randn( - state.positions.shape, - device=state.device, - generator=generator, + state1_orig.cell.shape, device=state1_orig.device, generator=generator ) - * 0.1 + * 0.01 ) - # Optimize each state individually + # Ensure state2_orig is identical to perturbed state1_orig + state2_orig = state1_orig.clone() + final_individual_states = [] - total_steps = [] - def energy_converged(current_energy: float, prev_energy: float) -> bool: - """Check if optimization should continue based on energy convergence.""" - return not torch.allclose(current_energy, prev_energy, atol=1e-6) + def energy_converged(current_e: torch.Tensor, prev_e: torch.Tensor) -> bool: + """Check for energy convergence (scalar energies).""" + return not torch.allclose(current_e, prev_e, atol=1e-6) - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, + for state_for_indiv_opt in [state1_orig.clone(), state2_orig.clone()]: + init_fn_indiv, update_fn_indiv = optimizer_func( + model=lj_model, dt_max=0.3, dt_start=0.1 ) + opt_state_indiv = init_fn_indiv(state_for_indiv_opt) - state_opt = init_fn(state) - - # Run optimization until convergence - current_energy = state_opt.energy - prev_energy = current_energy + 1 - - step = 0 - while energy_converged(current_energy, prev_energy): - prev_energy = current_energy - state_opt = update_fn(state_opt) - current_energy = state_opt.energy - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") - - final_individual_states.append(state_opt) - total_steps.append(step) + current_e_indiv = opt_state_indiv.energy + # Ensure prev_e_indiv is different to start the loop + prev_e_indiv = current_e_indiv + torch.tensor( + 1.0, device=current_e_indiv.device, dtype=current_e_indiv.dtype + ) - # Now optimize both states together in a batch - multi_state = concatenate_states( - [ - copy.deepcopy(ar_supercell_sim_state_1), - copy.deepcopy(ar_supercell_sim_state_2), - ], + steps_indiv = 0 + while energy_converged(current_e_indiv, prev_e_indiv): + prev_e_indiv = current_e_indiv + opt_state_indiv = 
update_fn_indiv(opt_state_indiv) + current_e_indiv = opt_state_indiv.energy + steps_indiv += 1 + if steps_indiv > 1000: + raise ValueError( + f"Individual opt for {optimizer_func.__name__} did not converge" + ) + final_individual_states.append(opt_state_indiv) + + # Batched optimization + multi_state_initial = concatenate_states( + [state1_orig.clone(), state2_orig.clone()], device=ar_supercell_sim_state.device, ) - init_fn, batch_update_fn = fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, + init_fn_batch, update_fn_batch = optimizer_func( + model=lj_model, dt_max=0.3, dt_start=0.1 ) + batch_opt_state = init_fn_batch(multi_state_initial) - batch_state = init_fn(multi_state) - - # Run optimization until convergence for both batches - current_energies = batch_state.energy.clone() - prev_energies = current_energies + 1 + current_energies_batch = batch_opt_state.energy.clone() + # Ensure prev_energies_batch requires update and has same shape + prev_energies_batch = current_energies_batch + torch.tensor( + 1.0, device=current_energies_batch.device, dtype=current_energies_batch.dtype + ) - step = 0 - while energy_converged(current_energies[0], prev_energies[0]) and energy_converged( - current_energies[1], prev_energies[1] - ): - prev_energies = current_energies.clone() - batch_state = batch_update_fn(batch_state) - current_energies = batch_state.energy.clone() - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") + steps_batch = 0 + # Converge when all batch energies have converged + while not torch.allclose(current_energies_batch, prev_energies_batch, atol=1e-6): + prev_energies_batch = current_energies_batch.clone() + batch_opt_state = update_fn_batch(batch_opt_state) + current_energies_batch = batch_opt_state.energy.clone() + steps_batch += 1 + if steps_batch > 1000: + raise ValueError( + f"Batched opt for {optimizer_func.__name__} did not converge" + ) - individual_energies = [state.energy.item() for state in final_individual_states] - # Check that final energies from batched optimization match individual optimizations - for step, individual_energy in enumerate(individual_energies): - assert abs(batch_state.energy[step].item() - individual_energy) < 1e-4, ( - f"Energy for batch {step} doesn't match individual optimization: " - f"batch={batch_state.energy[step].item()}, individual={individual_energy}" + individual_final_energies = [s.energy.item() for s in final_individual_states] + for idx, indiv_energy in enumerate(individual_final_energies): + assert abs(batch_opt_state.energy[idx].item() - indiv_energy) < 1e-4, ( + f"Energy batch {idx} ({optimizer_func.__name__}): " + f"{batch_opt_state.energy[idx].item()} vs indiv {indiv_energy}" ) + # Check positions changed for both parts of the batch + n_atoms_first_state = state1_orig.positions.shape[0] + assert not torch.allclose( + batch_opt_state.positions[:n_atoms_first_state], + multi_state_initial.positions[:n_atoms_first_state], + atol=1e-5, # Added tolerance as in original frechet test + ), f"{optimizer_func.__name__} positions batch 0 did not change." + assert not torch.allclose( + batch_opt_state.positions[n_atoms_first_state:], + multi_state_initial.positions[n_atoms_first_state:], + atol=1e-5, + ), f"{optimizer_func.__name__} positions batch 1 did not change." + + if optimizer_func in (unit_cell_fire, frechet_cell_fire): + assert not torch.allclose( + batch_opt_state.cell, multi_state_initial.cell, atol=1e-5 + ), f"{optimizer_func.__name__} cell did not change." 
+ def test_unit_cell_fire_multi_batch( ar_supercell_sim_state: SimState, lj_model: torch.nn.Module @@ -473,19 +810,7 @@ def test_unit_cell_fire_multi_batch( # transfer the energy and force checks to the batched optimizer max_force = torch.max(torch.norm(state.forces, dim=1)) assert torch.all(max_force < 0.1), ( - f"Forces should be small after optimization (got {max_force})" - ) - - pressure_0 = torch.trace(state.stress[0]) / 3.0 - pressure_1 = torch.trace(state.stress[1]) / 3.0 - assert torch.allclose(pressure_0, pressure_1), ( - f"Pressure should be the same for all batches (got {pressure_0} and {pressure_1})" - ) - assert pressure_0 < 0.01, ( - f"Pressure should be small after optimization (got {pressure_0})" - ) - assert pressure_1 < 0.01, ( - f"Pressure should be small after optimization (got {pressure_1})" + f"Forces should be small after optimization, got {max_force=}" ) n_ar_atoms = ar_supercell_sim_state.n_atoms @@ -495,16 +820,16 @@ def test_unit_cell_fire_multi_batch( assert not torch.allclose( state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:] ) - assert not torch.allclose(state.cell, multi_state.cell) # we are evolving identical systems assert current_energy[0] == current_energy[1] -def test_unit_cell_fire_batch_consistency( +def test_fire_fixed_cell_unit_cell_consistency( # noqa: C901 ar_supercell_sim_state: SimState, lj_model: torch.nn.Module ) -> None: - """Test batched FIRE optimization is consistent with individual optimizations.""" + """Test batched Frechet Fixed cell FIRE optimization is + consistent with FIRE (position only) optimizations.""" generator = torch.Generator(device=ar_supercell_sim_state.device) ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) @@ -514,17 +839,13 @@ def test_unit_cell_fire_batch_consistency( for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: generator.manual_seed(43) state.positions += ( - torch.randn( - state.positions.shape, - device=state.device, - generator=generator, - ) + torch.randn(state.positions.shape, device=state.device, generator=generator) * 0.1 ) # Optimize each state individually - final_individual_states = [] - total_steps = [] + final_individual_states_unit_cell = [] + total_steps_unit_cell = [] def energy_converged(current_energy: float, prev_energy: float) -> bool: """Check if optimization should continue based on energy convergence.""" @@ -535,6 +856,8 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: model=lj_model, dt_max=0.3, dt_start=0.1, + hydrostatic_strain=True, + constant_volume=True, ) state_opt = init_fn(state) @@ -552,307 +875,21 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: if step > 1000: raise ValueError("Optimization did not converge") - final_individual_states.append(state_opt) - total_steps.append(step) + final_individual_states_unit_cell.append(state_opt) + total_steps_unit_cell.append(step) - # Now optimize both states together in a batch - multi_state = concatenate_states( - [ - copy.deepcopy(ar_supercell_sim_state_1), - copy.deepcopy(ar_supercell_sim_state_2), - ], - device=ar_supercell_sim_state.device, - ) + # Optimize each state individually + final_individual_states_fire = [] + total_steps_fire = [] - init_fn, batch_update_fn = unit_cell_fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - ) + def energy_converged(current_energy: float, prev_energy: float) -> bool: + """Check if optimization should continue based on energy convergence.""" + return not torch.allclose(current_energy, 
prev_energy, atol=1e-6) - batch_state = init_fn(multi_state) + for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: + init_fn, update_fn = fire(model=lj_model, dt_max=0.3, dt_start=0.1) - # Run optimization until convergence for both batches - current_energies = batch_state.energy.clone() - prev_energies = current_energies + 1 - - step = 0 - while energy_converged(current_energies[0], prev_energies[0]) and energy_converged( - current_energies[1], prev_energies[1] - ): - prev_energies = current_energies.clone() - batch_state = batch_update_fn(batch_state) - current_energies = batch_state.energy.clone() - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") - - individual_energies = [state.energy.item() for state in final_individual_states] - # Check that final energies from batched optimization match individual optimizations - for step, individual_energy in enumerate(individual_energies): - assert abs(batch_state.energy[step].item() - individual_energy) < 1e-4, ( - f"Energy for batch {step} doesn't match individual optimization: " - f"batch={batch_state.energy[step].item()}, individual={individual_energy}" - ) - - -def test_unit_cell_frechet_fire_multi_batch( - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module -) -> None: - """Test FIRE optimization with multiple batches.""" - # Create a multi-batch system by duplicating ar_fcc_state - - generator = torch.Generator(device=ar_supercell_sim_state.device) - - ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) - ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) - - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - generator.manual_seed(43) - state.positions += ( - torch.randn( - state.positions.shape, - device=state.device, - generator=generator, - ) - * 0.1 - ) - - multi_state = concatenate_states( - [ar_supercell_sim_state_1, ar_supercell_sim_state_2], - device=ar_supercell_sim_state.device, - ) - - # Initialize FIRE optimizer - init_fn, update_fn = frechet_cell_fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - ) - - state = init_fn(multi_state) - initial_state = copy.deepcopy(state) - - # Run optimization for a few steps - prev_energy = torch.ones(2, device=state.device, dtype=state.energy.dtype) * 1000 - current_energy = initial_state.energy - step = 0 - while not torch.allclose(current_energy, prev_energy, atol=1e-9): - prev_energy = current_energy - state = update_fn(state) - current_energy = state.energy - - step += 1 - if step > 500: - raise ValueError("Optimization did not converge") - - # check that we actually optimized - assert step > 10 - - # Check that energy decreased for both batches - assert torch.all(state.energy < initial_state.energy), ( - "FIRE optimization should reduce energy for all batches" - ) - - # transfer the energy and force checks to the batched optimizer - max_force = torch.max(torch.norm(state.forces, dim=1)) - assert torch.all(max_force < 0.1), ( - f"Forces should be small after optimization (got {max_force})" - ) - - pressure_0 = torch.trace(state.stress[0]) / 3.0 - pressure_1 = torch.trace(state.stress[1]) / 3.0 - assert torch.allclose(pressure_0, pressure_1), ( - f"Pressure should be the same for all batches (got {pressure_0} and {pressure_1})" - ) - assert pressure_0 < 0.01, ( - f"Pressure should be small after optimization (got {pressure_0})" - ) - assert pressure_1 < 0.01, ( - f"Pressure should be small after optimization (got {pressure_1})" - ) - - n_ar_atoms = ar_supercell_sim_state.n_atoms - assert 
not torch.allclose( - state.positions[:n_ar_atoms], multi_state.positions[:n_ar_atoms] - ) - assert not torch.allclose( - state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:] - ) - assert not torch.allclose(state.cell, multi_state.cell) - - # we are evolving identical systems - assert current_energy[0] == current_energy[1] - - -def test_unit_cell_frechet_fire_batch_consistency( - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module -) -> None: - """Test batched FIRE optimization is consistent with individual optimizations.""" - generator = torch.Generator(device=ar_supercell_sim_state.device) - - ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) - ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) - - # Add same random perturbation to both states - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - generator.manual_seed(43) - state.positions += ( - torch.randn( - state.positions.shape, - device=state.device, - generator=generator, - ) - * 0.1 - ) - - # Optimize each state individually - final_individual_states = [] - total_steps = [] - - def energy_converged(current_energy: float, prev_energy: float) -> bool: - """Check if optimization should continue based on energy convergence.""" - return not torch.allclose(current_energy, prev_energy, atol=1e-6) - - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = frechet_cell_fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - ) - - state_opt = init_fn(state) - - # Run optimization until convergence - current_energy = state_opt.energy - prev_energy = current_energy + 1 - - step = 0 - while energy_converged(current_energy, prev_energy): - prev_energy = current_energy - state_opt = update_fn(state_opt) - current_energy = state_opt.energy - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") - - final_individual_states.append(state_opt) - total_steps.append(step) - - # Now optimize both states together in a batch - multi_state = concatenate_states( - [ - copy.deepcopy(ar_supercell_sim_state_1), - copy.deepcopy(ar_supercell_sim_state_2), - ], - device=ar_supercell_sim_state.device, - ) - - init_fn, batch_update_fn = frechet_cell_fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - ) - - batch_state = init_fn(multi_state) - - # Run optimization until convergence for both batches - current_energies = batch_state.energy.clone() - prev_energies = current_energies + 1 - - step = 0 - while energy_converged(current_energies[0], prev_energies[0]) and energy_converged( - current_energies[1], prev_energies[1] - ): - prev_energies = current_energies.clone() - batch_state = batch_update_fn(batch_state) - current_energies = batch_state.energy.clone() - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") - - individual_energies = [state.energy.item() for state in final_individual_states] - # Check that final energies from batched optimization match individual optimizations - for step, individual_energy in enumerate(individual_energies): - assert abs(batch_state.energy[step].item() - individual_energy) < 1e-4, ( - f"Energy for batch {step} doesn't match individual optimization: " - f"batch={batch_state.energy[step].item()}, individual={individual_energy}" - ) - - -def test_fire_fixed_cell_frechet_consistency( # noqa: C901 - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module -) -> None: - """Test batched Frechet Fixed cell FIRE optimization is - consistent with FIRE (position only) 
optimizations.""" - generator = torch.Generator(device=ar_supercell_sim_state.device) - - ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) - ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) - - # Add same random perturbation to both states - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - generator.manual_seed(43) - state.positions += ( - torch.randn( - state.positions.shape, - device=state.device, - generator=generator, - ) - * 0.1 - ) - - # Optimize each state individually - final_individual_states_frechet = [] - total_steps_frechet = [] - - def energy_converged(current_energy: float, prev_energy: float) -> bool: - """Check if optimization should continue based on energy convergence.""" - return not torch.allclose(current_energy, prev_energy, atol=1e-6) - - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = unit_cell_fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - hydrostatic_strain=True, - constant_volume=True, - ) - - state_opt = init_fn(state) - - # Run optimization until convergence - current_energy = state_opt.energy - prev_energy = current_energy + 1 - - step = 0 - while energy_converged(current_energy, prev_energy): - prev_energy = current_energy - state_opt = update_fn(state_opt) - current_energy = state_opt.energy - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") - - final_individual_states_frechet.append(state_opt) - total_steps_frechet.append(step) - - # Optimize each state individually - final_individual_states_fire = [] - total_steps_fire = [] - - def energy_converged(current_energy: float, prev_energy: float) -> bool: - """Check if optimization should continue based on energy convergence.""" - return not torch.allclose(current_energy, prev_energy, atol=1e-6) - - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - ) - - state_opt = init_fn(state) + state_opt = init_fn(state) # Run optimization until convergence current_energy = state_opt.energy @@ -865,112 +902,7 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: current_energy = state_opt.energy step += 1 if step > 1000: - raise ValueError("Optimization did not converge") - - final_individual_states_fire.append(state_opt) - total_steps_fire.append(step) - - individual_energies_frechet = [ - state.energy.item() for state in final_individual_states_frechet - ] - individual_energies_fire = [ - state.energy.item() for state in final_individual_states_fire - ] - # Check that final energies from fixed cell optimization match - # position only optimizations - for step, energy_frechet in enumerate(individual_energies_frechet): - assert abs(energy_frechet - individual_energies_fire[step]) < 1e-4, ( - f"Energy for batch {step} doesn't match position only optimization: " - f"batch={energy_frechet}, individual={individual_energies_fire[step]}" - ) - - -def test_fire_fixed_cell_unit_cell_consistency( # noqa: C901 - ar_supercell_sim_state: SimState, lj_model: torch.nn.Module -) -> None: - """Test batched Frechet Fixed cell FIRE optimization is - consistent with FIRE (position only) optimizations.""" - generator = torch.Generator(device=ar_supercell_sim_state.device) - - ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) - ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) - - # Add same random perturbation to both states - for state in [ar_supercell_sim_state_1, 
ar_supercell_sim_state_2]: - generator.manual_seed(43) - state.positions += ( - torch.randn( - state.positions.shape, - device=state.device, - generator=generator, - ) - * 0.1 - ) - - # Optimize each state individually - final_individual_states_unit_cell = [] - total_steps_unit_cell = [] - - def energy_converged(current_energy: float, prev_energy: float) -> bool: - """Check if optimization should continue based on energy convergence.""" - return not torch.allclose(current_energy, prev_energy, atol=1e-6) - - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = unit_cell_fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - hydrostatic_strain=True, - constant_volume=True, - ) - - state_opt = init_fn(state) - - # Run optimization until convergence - current_energy = state_opt.energy - prev_energy = current_energy + 1 - - step = 0 - while energy_converged(current_energy, prev_energy): - prev_energy = current_energy - state_opt = update_fn(state_opt) - current_energy = state_opt.energy - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") - - final_individual_states_unit_cell.append(state_opt) - total_steps_unit_cell.append(step) - - # Optimize each state individually - final_individual_states_fire = [] - total_steps_fire = [] - - def energy_converged(current_energy: float, prev_energy: float) -> bool: - """Check if optimization should continue based on energy convergence.""" - return not torch.allclose(current_energy, prev_energy, atol=1e-6) - - for state in [ar_supercell_sim_state_1, ar_supercell_sim_state_2]: - init_fn, update_fn = fire( - model=lj_model, - dt_max=0.3, - dt_start=0.1, - ) - - state_opt = init_fn(state) - - # Run optimization until convergence - current_energy = state_opt.energy - prev_energy = current_energy + 1 - - step = 0 - while energy_converged(current_energy, prev_energy): - prev_energy = current_energy - state_opt = update_fn(state_opt) - current_energy = state_opt.energy - step += 1 - if step > 1000: - raise ValueError("Optimization did not converge") + raise ValueError(f"Optimization did not converge in {step=}") final_individual_states_fire.append(state_opt) total_steps_fire.append(step) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index a6ec58376..b16ab7502 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -12,11 +12,15 @@ * FIRE (Fast Inertial Relaxation Engine) optimization with unit cell parameters * FIRE optimization with Frechet cell parameterization for improved cell relaxation +ASE-style FIRE: https://gitlab.com/ase/ase/-/blob/master/ase/optimize/fire.py?ref_type=heads +Velocity Verlet-style FIRE: https://doi.org/10.1103/PhysRevLett.97.170201 + """ +import functools from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, Literal, get_args import torch @@ -25,6 +29,10 @@ from torch_sim.typing import StateDict +MdFlavor = Literal["vv_fire", "ase_fire"] +vv_fire_key, ase_fire_key = get_args(MdFlavor) + + @dataclass class GDState(SimState): """State class for batched gradient descent optimization. 
@@ -49,13 +57,8 @@ class GDState(SimState): def gradient_descent( - model: torch.nn.Module, - *, - lr: torch.Tensor | float = 0.01, -) -> tuple[ - Callable[[StateDict | SimState], GDState], - Callable[[GDState], GDState], -]: + model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01 +) -> tuple[Callable[[StateDict | SimState], GDState], Callable[[GDState], GDState]]: """Initialize a batched gradient descent optimization. Creates an optimizer that performs standard gradient descent on atomic positions @@ -489,8 +492,10 @@ def fire( f_dec: float = 0.5, alpha_start: float = 0.1, f_alpha: float = 0.99, + max_step: float = 0.2, + md_flavor: MdFlavor = ase_fire_key, ) -> tuple[ - FireState, + Callable[[SimState | StateDict], FireState], Callable[[FireState], FireState], ]: """Initialize a batched FIRE optimization. @@ -507,27 +512,43 @@ def fire( f_dec (float): Factor for timestep decrease when power is negative alpha_start (float): Initial velocity mixing parameter f_alpha (float): Factor for mixing parameter decrease + max_step (float): Maximum distance an atom can move per iteration (default + value is 0.2). Only used when md_flavor='ase_fire'. + md_flavor (MdFlavor): Optimization flavor, either "vv_fire" or "ase_fire". + Default is "ase_fire". Returns: - tuple: A pair of functions: + tuple[Callable, Callable]: - Initialization function that creates a FireState - - Update function that performs one FIRE optimization step + - Update function (either vv_fire_step or ase_fire_step) that performs + one FIRE optimization step. Notes: + - md_flavor="vv_fire" follows the original paper closely, including + integration with Velocity Verlet steps. See https://doi.org/10.1103/PhysRevLett.97.170201 + and https://github.com/Radical-AI/torch-sim/issues/90#issuecomment-2826179997 + for details. + - md_flavor="ase_fire" mimics the implementation in ASE, which differs slightly + in the update steps and does not explicitly use atomic masses in the + velocity update step. See https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214 + for details. - FIRE is generally more efficient than standard gradient descent for atomic - structure optimization + structure optimization. - The algorithm adaptively adjusts step sizes and mixing parameters based - on the dot product of forces and velocities + on the dot product of forces and velocities (power). 
""" + if md_flavor not in get_args(MdFlavor): + raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") + device, dtype = model.device, model.dtype eps = 1e-8 if dtype == torch.float32 else 1e-16 - # Setup parameters - params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min] - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [ - torch.as_tensor(p, device=device, dtype=dtype) for p in params - ] + # Setup parameters, added max_step for ASE style + dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( + torch.as_tensor(p, device=device, dtype=dtype) + for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) + ) def fire_init( state: SimState | StateDict, @@ -559,11 +580,9 @@ def fire_init( # Setup parameters dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) - # Create initial state - return FireState( + return FireState( # Create initial state # Copy SimState attributes positions=state.positions.clone(), masses=state.masses.clone(), @@ -581,97 +600,22 @@ def fire_init( n_pos=n_pos, ) - def fire_step( - state: FireState, - alpha_start: float = alpha_start, - dt_start: float = dt_start, - ) -> FireState: - """Perform one FIRE optimization step for batched atomic systems. - - Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for - optimizing atomic positions in a batched setting. Uses velocity Verlet - integration with adaptive velocity mixing. - - Args: - state: Current optimization state containing atomic parameters - alpha_start: Initial mixing parameter for velocity update - dt_start: Initial timestep for velocity Verlet integration - - Returns: - Updated state after performing one FIRE step - """ - n_batches = state.n_batches - - # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - - # Velocity Verlet first half step (v += 0.5*a*dt) - atom_wise_dt = state.dt[state.batch].unsqueeze(-1) - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) - - # Split positions and forces into atomic and cell components - atomic_positions = state.positions # shape: (n_atoms, 3) - - # Update atomic positions - atomic_positions_new = atomic_positions + atom_wise_dt * state.velocities - - # Update state with new positions and cell - state.positions = atomic_positions_new - - # Get new forces, energy, and stress - results = model(state) - state.energy = results["energy"] - state.forces = results["forces"] - - # Velocity Verlet first half step (v += 0.5*a*dt) - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) - - # Calculate power (F·V) for atoms - atomic_power = (state.forces * state.velocities).sum(dim=1) # [n_atoms] - atomic_power_per_batch = torch.zeros( - n_batches, device=device, dtype=atomic_power.dtype - ) - atomic_power_per_batch.scatter_add_( - dim=0, index=state.batch, src=atomic_power - ) # [n_batches] - - # Calculate power for cell DOFs - batch_power = atomic_power_per_batch - - for batch_idx in range(n_batches): - # FIRE specific updates - if batch_power[batch_idx] > 0: # Power is positive - state.n_pos[batch_idx] += 1 - if state.n_pos[batch_idx] > n_min: - state.dt[batch_idx] = min(state.dt[batch_idx] * f_inc, dt_max) - 
state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha - else: # Power is negative - state.n_pos[batch_idx] = 0 - state.dt[batch_idx] = state.dt[batch_idx] * f_dec - state.alpha[batch_idx] = alpha_start[batch_idx] - # Reset velocities for both atoms and cell - state.velocities[state.batch == batch_idx] = 0 - - # Mix velocity and force direction using FIRE for atoms - v_norm = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm = torch.norm(state.forces, dim=1, keepdim=True) - # Avoid division by zero - # mask = f_norm > 1e-10 - # state.velocity = torch.where( - # mask, - # (1.0 - state.alpha) * state.velocity - # + state.alpha * state.forces * v_norm / f_norm, - # state.velocity, - # ) - atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1) - state.velocities = ( - 1.0 - atom_wise_alpha - ) * state.velocities + atom_wise_alpha * state.forces * v_norm / (f_norm + eps) - - return state - - return fire_init, fire_step + step_func_kwargs = dict( + model=model, + dt_max=dt_max, + n_min=n_min, + f_inc=f_inc, + f_dec=f_dec, + alpha_start=alpha_start, + f_alpha=f_alpha, + eps=eps, + is_cell_optimization=False, + is_frechet=False, + ) + if md_flavor == ase_fire_key: + step_func_kwargs["max_step"] = max_step + step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor] + return fire_init, functools.partial(step_func, **step_func_kwargs) @dataclass @@ -749,7 +693,7 @@ class UnitCellFireState(SimState, DeformGradMixin): n_pos: torch.Tensor -def unit_cell_fire( # noqa: C901, PLR0915 +def unit_cell_fire( model: torch.nn.Module, *, dt_max: float = 1.0, @@ -763,6 +707,8 @@ def unit_cell_fire( # noqa: C901, PLR0915 hydrostatic_strain: bool = False, constant_volume: bool = False, scalar_pressure: float = 0.0, + max_step: float = 0.2, + md_flavor: MdFlavor = ase_fire_key, ) -> tuple[ UnitCellFireState, Callable[[UnitCellFireState], UnitCellFireState], @@ -789,6 +735,9 @@ def unit_cell_fire( # noqa: C901, PLR0915 (isotropic scaling) constant_volume (bool): Whether to maintain constant volume during optimization scalar_pressure (float): Applied external pressure in GPa + max_step (float): Maximum allowed step size for ase_fire + md_flavor (MdFlavor): Optimization flavor, either "vv_fire" or "ase_fire". + Default is "ase_fire". Returns: tuple: A pair of functions: @@ -796,6 +745,14 @@ def unit_cell_fire( # noqa: C901, PLR0915 - Update function that performs one FIRE optimization step Notes: + - md_flavor="vv_fire" follows the original paper closely, including + integration with Velocity Verlet steps. See https://doi.org/10.1103/PhysRevLett.97.170201 + and https://github.com/Radical-AI/torch-sim/issues/90#issuecomment-2826179997 + for details. + - md_flavor="ase_fire" mimics the implementation in ASE, which differs slightly + in the update steps and does not explicitly use atomic masses in the + velocity update step. See https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214 + for details. 
- FIRE is generally more efficient than standard gradient descent for atomic structure optimization - The algorithm adaptively adjusts step sizes and mixing parameters based @@ -805,15 +762,17 @@ def unit_cell_fire( # noqa: C901, PLR0915 - The cell_factor parameter controls the relative scale of atomic vs cell optimization """ + if md_flavor not in get_args(MdFlavor): + raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") device, dtype = model.device, model.dtype eps = 1e-8 if dtype == torch.float32 else 1e-16 # Setup parameters - params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min] - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [ - torch.as_tensor(p, device=device, dtype=dtype) for p in params - ] + dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( + torch.as_tensor(p, device=device, dtype=dtype) + for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) + ) def fire_init( state: SimState | StateDict, @@ -896,11 +855,9 @@ def fire_init( # Setup parameters dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) - # Create initial state - return UnitCellFireState( + return UnitCellFireState( # Create initial state # Copy SimState attributes positions=state.positions.clone(), masses=state.masses.clone(), @@ -929,157 +886,22 @@ def fire_init( constant_volume=constant_volume, ) - def fire_step( # noqa: PLR0915 - state: UnitCellFireState, - alpha_start: float = alpha_start, - dt_start: float = dt_start, - ) -> UnitCellFireState: - """Perform one FIRE optimization step for batched atomic systems with unit cell - optimization. - - Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for - optimizing atomic positions and unit cell parameters in a batched setting. Uses - velocity Verlet integration with adaptive velocity mixing. 
- - Args: - state: Current optimization state containing atomic and cell parameters - alpha_start: Initial mixing parameter for velocity update - dt_start: Initial timestep for velocity Verlet integration - - Returns: - Updated state after performing one FIRE step - """ - n_batches = state.n_batches - - # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) - - # Calculate current deformation gradient - cur_deform_grad = torch.transpose( - torch.linalg.solve(state.reference_cell, state.cell), 1, 2 - ) # shape: (n_batches, 3, 3) - - # Calculate cell positions from deformation gradient - cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) - cell_positions = cur_deform_grad * cell_factor_expanded - - # Velocity Verlet first half step (v += 0.5*a*dt) - atom_wise_dt = state.dt[state.batch].unsqueeze(-1) - cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1) - - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) - state.cell_velocities += ( - 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) - ) - - # Split positions and forces into atomic and cell components - atomic_positions = state.positions # shape: (n_atoms, 3) - - # Update atomic and cell positions - atomic_positions_new = atomic_positions + atom_wise_dt * state.velocities - cell_positions_new = cell_positions + cell_wise_dt * state.cell_velocities - - # Update cell with deformation gradient - cell_update = cell_positions_new / cell_factor_expanded - new_cell = torch.bmm(state.reference_cell, cell_update.transpose(1, 2)) - - # Update state with new positions and cell - state.positions = atomic_positions_new - state.cell_positions = cell_positions_new - state.cell = new_cell - - # Get new forces, energy, and stress - results = model(state) - state.energy = results["energy"] - forces = results["forces"] - stress = results["stress"] - - state.forces = forces - state.stress = stress - # Calculate virial - volumes = torch.linalg.det(new_cell).view(-1, 1, 1) - virial = -volumes * (stress + state.pressure) - if state.hydrostatic_strain: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( - 0 - ).expand(n_batches, -1, -1) - if state.constant_volume: - diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device - ).unsqueeze(0).expand(n_batches, -1, -1) - - state.cell_forces = virial / state.cell_factor - - # Velocity Verlet first half step (v += 0.5*a*dt) - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) - state.cell_velocities += ( - 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) - ) - - # Calculate power (F·V) for atoms - atomic_power = (state.forces * state.velocities).sum(dim=1) # [n_atoms] - atomic_power_per_batch = torch.zeros( - n_batches, device=device, dtype=atomic_power.dtype - ) - atomic_power_per_batch.scatter_add_( - dim=0, index=state.batch, src=atomic_power - ) # [n_batches] - - # Calculate power for cell DOFs - cell_power = (state.cell_forces * state.cell_velocities).sum( - dim=(1, 2) - ) # [n_batches] - batch_power = atomic_power_per_batch + cell_power - - for batch_idx in range(n_batches): - # FIRE specific updates - if batch_power[batch_idx] > 0: # Power is positive - state.n_pos[batch_idx] 
+= 1 - if state.n_pos[batch_idx] > n_min: - state.dt[batch_idx] = min(state.dt[batch_idx] * f_inc, dt_max) - state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha - else: # Power is negative - state.n_pos[batch_idx] = 0 - state.dt[batch_idx] = state.dt[batch_idx] * f_dec - state.alpha[batch_idx] = alpha_start[batch_idx] - # Reset velocities for both atoms and cell - state.velocities[state.batch == batch_idx] = 0 - state.cell_velocities[batch_idx] = 0 - - # Mix velocity and force direction using FIRE for atoms - v_norm = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm = torch.norm(state.forces, dim=1, keepdim=True) - # Avoid division by zero - # mask = f_norm > 1e-10 - # state.velocity = torch.where( - # mask, - # (1.0 - state.alpha) * state.velocity - # + state.alpha * state.forces * v_norm / f_norm, - # state.velocity, - # ) - batch_wise_alpha = state.alpha[state.batch].unsqueeze(-1) - state.velocities = ( - 1.0 - batch_wise_alpha - ) * state.velocities + batch_wise_alpha * state.forces * v_norm / (f_norm + eps) - - # Mix velocity and force direction for cell DOFs - cell_v_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cell_f_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cell_wise_alpha = state.alpha.unsqueeze(-1).unsqueeze(-1) - cell_mask = cell_f_norm > eps - state.cell_velocities = torch.where( - cell_mask, - (1.0 - cell_wise_alpha) * state.cell_velocities - + cell_wise_alpha * state.cell_forces * cell_v_norm / cell_f_norm, - state.cell_velocities, - ) - - return state - - return fire_init, fire_step + step_func_kwargs = dict( + model=model, + dt_max=dt_max, + n_min=n_min, + f_inc=f_inc, + f_dec=f_dec, + alpha_start=alpha_start, + f_alpha=f_alpha, + eps=eps, + is_cell_optimization=True, + is_frechet=False, + ) + if md_flavor == ase_fire_key: + step_func_kwargs["max_step"] = max_step + step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor] + return fire_init, functools.partial(step_func, **step_func_kwargs) @dataclass @@ -1157,7 +979,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): n_pos: torch.Tensor -def frechet_cell_fire( # noqa: C901, PLR0915 +def frechet_cell_fire( model: torch.nn.Module, *, dt_max: float = 1.0, @@ -1171,6 +993,8 @@ def frechet_cell_fire( # noqa: C901, PLR0915 hydrostatic_strain: bool = False, constant_volume: bool = False, scalar_pressure: float = 0.0, + max_step: float = 0.2, + md_flavor: MdFlavor = ase_fire_key, ) -> tuple[ FrechetCellFIREState, Callable[[FrechetCellFIREState], FrechetCellFIREState], @@ -1198,6 +1022,9 @@ def frechet_cell_fire( # noqa: C901, PLR0915 (isotropic scaling) constant_volume (bool): Whether to maintain constant volume during optimization scalar_pressure (float): Applied external pressure in GPa + max_step (float): Maximum allowed step size for ase_fire + md_flavor (MdFlavor): Optimization flavor, either "vv_fire" or "ase_fire". + Default is "ase_fire". Returns: tuple: A pair of functions: @@ -1205,6 +1032,14 @@ def frechet_cell_fire( # noqa: C901, PLR0915 - Update function that performs one FIRE step with Frechet derivatives Notes: + - md_flavor="vv_fire" follows the original paper closely, including + integration with Velocity Verlet steps. See https://doi.org/10.1103/PhysRevLett.97.170201 + and https://github.com/Radical-AI/torch-sim/issues/90#issuecomment-2826179997 + for details. 
+ - md_flavor="ase_fire" mimics the implementation in ASE, which differs slightly + in the update steps and does not explicitly use atomic masses in the + velocity update step. See https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214 + for details. - Frechet cell parameterization uses matrix logarithm to represent cell deformations, which provides improved numerical properties for cell optimization @@ -1213,15 +1048,17 @@ def frechet_cell_fire( # noqa: C901, PLR0915 - To fix the cell and only optimize atomic positions, set both constant_volume=True and hydrostatic_strain=True """ + if md_flavor not in get_args(MdFlavor): + raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}") device, dtype = model.device, model.dtype eps = 1e-8 if dtype == torch.float32 else 1e-16 # Setup parameters - params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min] - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [ - torch.as_tensor(p, device=device, dtype=dtype) for p in params - ] + dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = ( + torch.as_tensor(p, device=device, dtype=dtype) + for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step) + ) def fire_init( state: SimState | StateDict, @@ -1321,8 +1158,7 @@ def fire_init( alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32) - # Create initial state - return FrechetCellFIREState( + return FrechetCellFIREState( # Create initial state # Copy SimState attributes positions=state.positions, masses=state.masses, @@ -1351,195 +1187,470 @@ def fire_init( constant_volume=constant_volume, ) - def fire_step( # noqa: PLR0915 - state: FrechetCellFIREState, - alpha_start: float = alpha_start, - dt_start: float = dt_start, - ) -> FrechetCellFIREState: - """Perform one FIRE optimization step for batched atomic systems with - Frechet cell parameterization. - - Implements one step of the Fast Inertial Relaxation Engine (FIRE) - algorithm for optimizing atomic positions and unit cell parameters - using matrix logarithm parameterization for the cell degrees of freedom. - - Args: - state: Current optimization state containing atomic and cell parameters - alpha_start: Initial mixing parameter for velocity update - dt_start: Initial timestep for velocity Verlet integration - - Returns: - Updated state after performing one FIRE step with Frechet cell derivatives - """ - n_batches = state.n_batches + step_func_kwargs = dict( + model=model, + dt_max=dt_max, + n_min=n_min, + f_inc=f_inc, + f_dec=f_dec, + alpha_start=alpha_start, + f_alpha=f_alpha, + eps=eps, + is_cell_optimization=True, + is_frechet=True, + ) + if md_flavor == ase_fire_key: + step_func_kwargs["max_step"] = max_step + step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor] + return fire_init, functools.partial(step_func, **step_func_kwargs) + + +def _vv_fire_step( # noqa: C901, PLR0915 + state: FireState | UnitCellFireState | FrechetCellFIREState, + model: torch.nn.Module, + *, + dt_max: torch.Tensor, + n_min: torch.Tensor, + f_inc: torch.Tensor, + f_dec: torch.Tensor, + alpha_start: torch.Tensor, + f_alpha: torch.Tensor, + eps: float, + is_cell_optimization: bool = False, + is_frechet: bool = False, +) -> FireState | UnitCellFireState | FrechetCellFIREState: + """Perform one Velocity-Verlet based FIRE optimization step. 
+ + Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for + optimizing atomic positions and optionally unit cell parameters in a batched setting. + Uses velocity Verlet integration with adaptive velocity mixing. - # Setup parameters - dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype) - alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype) + Args: + state: Current optimization state (FireState, UnitCellFireState, or + FrechetCellFIREState). + model: Model that computes energies, forces, and potentially stress. + dt_max: Maximum allowed timestep. + n_min: Minimum steps before timestep increase. + f_inc: Factor for timestep increase when power is positive. + f_dec: Factor for timestep decrease when power is negative. + alpha_start: Initial mixing parameter for velocity update. + f_alpha: Factor for mixing parameter decrease. + eps: Small epsilon value for numerical stability. + is_cell_optimization: Flag indicating if cell optimization is active. + is_frechet: Flag indicating if Frechet cell parameterization is used. - # Calculate current deformation gradient - cur_deform_grad = state.deform_grad() # shape: (n_batches, 3, 3) + Returns: + Updated state after performing one VV-FIRE step. + """ + n_batches = state.n_batches + device = state.positions.device + dtype = state.positions.dtype + deform_grad_new: torch.Tensor | None = None - # Calculate log of deformation gradient - deform_grad_log = torch.zeros_like(cur_deform_grad) - for b in range(n_batches): - deform_grad_log[b] = tsm.matrix_log_33(cur_deform_grad[b]) + alpha_start_batch = torch.full( + (n_batches,), alpha_start.item(), device=device, dtype=dtype + ) - # Scale to get cell positions - cell_positions = deform_grad_log * state.cell_factor + atom_wise_dt = state.dt[state.batch].unsqueeze(-1) + state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) - # Velocity Verlet first half step (v += 0.5*a*dt) - atom_wise_dt = state.dt[state.batch].unsqueeze(-1) + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1) - - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) - # Split positions and forces into atomic and cell components - atomic_positions = state.positions # shape: (n_atoms, 3) - - # Update atomic and cell positions - atomic_positions_new = atomic_positions + atom_wise_dt * state.velocities - cell_positions_new = cell_positions + cell_wise_dt * state.cell_velocities - - # Convert cell positions to deformation gradient - deform_grad_log_new = cell_positions_new / state.cell_factor - - # deform_grad_new = torch.zeros_like(deform_grad_log_new) - # for b in range(n_batches): - # deform_grad_new[b] = expm.apply(deform_grad_log_new[b]) - - deform_grad_new = torch.matrix_exp(deform_grad_log_new) - - # Update cell with deformation gradient - new_row_vector_cell = torch.bmm( - state.reference_row_vector_cell, deform_grad_new.transpose(1, 2) - ) - - # Update state with new positions and cell - state.positions = atomic_positions_new - state.row_vector_cell = new_row_vector_cell - state.cell_positions = cell_positions_new - - # Get new forces and energy - results = model(state) - state.energy = results["energy"] + state.positions = state.positions + atom_wise_dt * state.velocities + + if is_cell_optimization: + assert 
isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) + if is_frechet: + assert isinstance(state, FrechetCellFIREState) + cur_deform_grad = state.deform_grad() + deform_grad_log = torch.zeros_like(cur_deform_grad) + for b in range(n_batches): + deform_grad_log[b] = tsm.matrix_log_33(cur_deform_grad[b]) + + cell_positions_log_scaled = deform_grad_log * cell_factor_reshaped + cell_positions_log_scaled_new = ( + cell_positions_log_scaled + cell_wise_dt * state.cell_velocities + ) + deform_grad_log_new = cell_positions_log_scaled_new / cell_factor_reshaped + deform_grad_new = torch.matrix_exp(deform_grad_log_new) + new_row_vector_cell = torch.bmm( + state.reference_row_vector_cell, deform_grad_new.transpose(1, 2) + ) + state.row_vector_cell = new_row_vector_cell + state.cell_positions = cell_positions_log_scaled_new + else: + assert isinstance(state, UnitCellFireState) + cur_deform_grad = state.deform_grad() + # cell_factor is (N,1,1) + cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) + current_cell_positions_scaled = ( + cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded + ) - # Combine new atomic forces and cell forces - forces = results["forces"] - stress = results["stress"] + cell_positions_scaled_new = ( + current_cell_positions_scaled + cell_wise_dt * state.cell_velocities + ) + cell_update = cell_positions_scaled_new / cell_factor_expanded + new_cell = torch.bmm( + state.reference_row_vector_cell, cell_update.transpose(1, 2) + ) + state.row_vector_cell = new_cell + state.cell_positions = cell_positions_scaled_new - state.forces = forces - state.stress = stress + results = model(state) + state.forces = results["forces"] + state.energy = results["energy"] - # Calculate virial + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + state.stress = results["stress"] volumes = torch.linalg.det(state.cell).view(-1, 1, 1) - virial = -volumes * (stress + state.pressure) # P is P_ext * I + virial = -volumes * (state.stress + state.pressure) + if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) - virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( - 0 - ).expand(n_batches, -1, -1) + virial = diag_mean.unsqueeze(-1) * torch.eye( + 3, device=device, dtype=dtype + ).unsqueeze(0).expand(n_batches, -1, -1) if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( - 3, device=device + 3, device=device, dtype=dtype ).unsqueeze(0).expand(n_batches, -1, -1) - # Perform batched matrix multiplication - ucf_cell_grad = torch.bmm( - virial, torch.linalg.inv(torch.transpose(deform_grad_new, 1, 2)) - ) - - # Pre-compute all 9 direction matrices - directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) - for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]): - directions[idx, mu, nu] = 1.0 - - # Calculate cell forces batch by batch - cell_forces = torch.zeros_like(ucf_cell_grad) - for b in range(n_batches): - # Calculate all 9 Frechet derivatives at once - expm_derivs = torch.stack( - [ - tsm.expm_frechet( - deform_grad_log_new[b], direction, compute_expm=False - ) - for direction in directions - ] + if is_frechet: + assert isinstance(state, FrechetCellFIREState) + ucf_cell_grad = torch.bmm( + virial, torch.linalg.inv(torch.transpose(deform_grad_new, 1, 2)) ) - - # 
Calculate all 9 cell forces components - forces_flat = torch.sum( - expm_derivs * ucf_cell_grad[b].unsqueeze(0), dim=(1, 2) - ) - cell_forces[b] = forces_flat.reshape(3, 3) - - # Scale by cell_factor - cell_forces = cell_forces / state.cell_factor - state.cell_forces = cell_forces - - # Velocity Verlet second half step (v += 0.5*a*dt) - state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) + directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) + for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]): + directions[idx, mu, nu] = 1.0 + + new_cell_forces = torch.zeros_like(ucf_cell_grad) + for b in range(n_batches): + expm_derivs = torch.stack( + [ + tsm.expm_frechet( + deform_grad_log_new[b], direction, compute_expm=False + ) + for direction in directions + ] + ) + forces_flat = torch.sum( + expm_derivs * ucf_cell_grad[b].unsqueeze(0), dim=(1, 2) + ) + new_cell_forces[b] = forces_flat.reshape(3, 3) + state.cell_forces = new_cell_forces / cell_factor_reshaped + else: + assert isinstance(state, UnitCellFireState) + state.cell_forces = virial / cell_factor_reshaped + + state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) - # Calculate power (F·V) for atoms - atomic_power = (state.forces * state.velocities).sum(dim=1) # [n_atoms] - atomic_power_per_batch = torch.zeros( - n_batches, device=device, dtype=atomic_power.dtype - ) - atomic_power_per_batch.scatter_add_( - dim=0, index=state.batch, src=atomic_power - ) # [n_batches] - - # Calculate power for cell DOFs - cell_power = (state.cell_forces * state.cell_velocities).sum( - dim=(1, 2) - ) # [n_batches] - batch_power = atomic_power_per_batch + cell_power - - # FIRE updates for each batch - for batch_idx in range(n_batches): - # FIRE specific updates - if batch_power[batch_idx] > 0: - # Power is positive - state.n_pos[batch_idx] += 1 - if state.n_pos[batch_idx] > n_min: - state.dt[batch_idx] = min(state.dt[batch_idx] * f_inc, dt_max) - state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha - else: - # Power is negative - state.n_pos[batch_idx] = 0 - state.dt[batch_idx] = state.dt[batch_idx] * f_dec - state.alpha[batch_idx] = alpha_start[batch_idx] - # Reset velocities for both atoms and cell - state.velocities[state.batch == batch_idx] = 0 + atomic_power = (state.forces * state.velocities).sum(dim=1) + atomic_power_per_batch = torch.zeros( + n_batches, device=device, dtype=atomic_power.dtype + ) + atomic_power_per_batch.scatter_add_(dim=0, index=state.batch, src=atomic_power) + batch_power = atomic_power_per_batch + + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + batch_power += cell_power + + for batch_idx in range(n_batches): + if batch_power[batch_idx] > 0: + state.n_pos[batch_idx] += 1 + if state.n_pos[batch_idx] > n_min: + state.dt[batch_idx] = torch.minimum(state.dt[batch_idx] * f_inc, dt_max) + state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha + else: + state.n_pos[batch_idx] = 0 + state.dt[batch_idx] = state.dt[batch_idx] * f_dec + state.alpha[batch_idx] = alpha_start_batch[batch_idx] + state.velocities[state.batch == batch_idx] = 0 + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, 
FrechetCellFIREState)) state.cell_velocities[batch_idx] = 0 - # Mix velocity and force direction using FIRE for atoms - v_norm = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm = torch.norm(state.forces, dim=1, keepdim=True) - batch_wise_alpha = state.alpha[state.batch].unsqueeze(-1) - state.velocities = ( - 1.0 - batch_wise_alpha - ) * state.velocities + batch_wise_alpha * state.forces * v_norm / (f_norm + eps) + v_norm = torch.norm(state.velocities, dim=1, keepdim=True) + f_norm = torch.norm(state.forces, dim=1, keepdim=True) + atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1) + state.velocities = (1.0 - atom_wise_alpha) * state.velocities + ( + atom_wise_alpha * state.forces * v_norm / (f_norm + eps) + ) - # Mix velocity and force direction for cell DOFs + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_v_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) cell_f_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) cell_wise_alpha = state.alpha.unsqueeze(-1).unsqueeze(-1) - cell_mask = cell_f_norm > eps + cell_mask = (cell_f_norm > eps).expand_as(state.cell_velocities) state.cell_velocities = torch.where( cell_mask, (1.0 - cell_wise_alpha) * state.cell_velocities - + cell_wise_alpha * state.cell_forces * cell_v_norm / cell_f_norm, + + cell_wise_alpha * state.cell_forces * cell_v_norm / (cell_f_norm + eps), state.cell_velocities, ) - return state + return state + + +def _ase_fire_step( # noqa: C901, PLR0915 + state: FireState | UnitCellFireState | FrechetCellFIREState, + model: torch.nn.Module, + *, + dt_max: torch.Tensor, + n_min: torch.Tensor, + f_inc: torch.Tensor, + f_dec: torch.Tensor, + alpha_start: torch.Tensor, + f_alpha: torch.Tensor, + max_step: torch.Tensor, + eps: float, + is_cell_optimization: bool = False, + is_frechet: bool = False, +) -> FireState | UnitCellFireState | FrechetCellFIREState: + """Perform one ASE-style FIRE optimization step. + + Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm + mimicking the ASE implementation. It can handle atomic position optimization + only, or combined position and cell optimization (standard or Frechet). + + Args: + state: Current optimization state. + model: Model that computes energies, forces, and potentially stress. + dt_max: Maximum allowed timestep. + n_min: Minimum steps before timestep increase. + f_inc: Factor for timestep increase when power is positive. + f_dec: Factor for timestep decrease when power is negative. + alpha_start: Initial mixing parameter for velocity update. + f_alpha: Factor for mixing parameter decrease. + max_step: Maximum allowed step size. + eps: Small epsilon value for numerical stability. + is_cell_optimization: Flag indicating if cell optimization is active. + is_frechet: Flag indicating if Frechet cell parameterization is used. - return fire_init, fire_step + Returns: + Updated state after performing one ASE-FIRE step. + """ + device, dtype = state.positions.device, state.positions.dtype + n_batches = state.n_batches + + # Setup batch-wise alpha_start for potential reset + # alpha_start is a 0-dim tensor from the factory + alpha_start_batch = torch.full( + (n_batches,), alpha_start.item(), device=device, dtype=dtype + ) + + # 1. 
Current power (F·v) per batch (atoms + cell) + atomic_power = (state.forces * state.velocities).sum(dim=1) + batch_power = torch.zeros(n_batches, device=device, dtype=dtype) + batch_power.scatter_add_(0, state.batch, atomic_power) + + if is_cell_optimization: + valid_states = (UnitCellFireState, FrechetCellFIREState) + assert isinstance(state, valid_states), ( + f"Cell optimization requires one of {valid_states}." + ) + cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + batch_power += cell_power + + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch + + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 + + # 3. Velocity mixing BEFORE acceleration (ASE ordering) + # Atoms + v_norm_atom = torch.norm(state.velocities, dim=1, keepdim=True) + f_norm_atom = torch.norm(state.forces, dim=1, keepdim=True) + f_unit_atom = state.forces / (f_norm_atom + eps) + alpha_atom = state.alpha[state.batch].unsqueeze(-1) + pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) + v_new_atom = ( + 1.0 - alpha_atom + ) * state.velocities + alpha_atom * f_unit_atom * v_norm_atom + state.velocities = torch.where( + pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities) + ) + + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + # Cell velocity mixing + cv_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) + cf_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) + cf_unit = state.cell_forces / (cf_norm + eps) + alpha_cell_bc = state.alpha.view(-1, 1, 1) + pos_mask_cell_bc = pos_mask_batch.view(-1, 1, 1) + v_new_cell = ( + 1.0 - alpha_cell_bc + ) * state.cell_velocities + alpha_cell_bc * cf_unit * cv_norm + state.cell_velocities = torch.where( + pos_mask_cell_bc, v_new_cell, torch.zeros_like(state.cell_velocities) + ) + + # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) + atom_dt = state.dt[state.batch].unsqueeze(-1) + state.velocities += atom_dt * state.forces + + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + cell_dt = state.dt.view(-1, 1, 1) + state.cell_velocities += cell_dt * state.cell_forces + + # 5. Displacements + dr_atom = atom_dt * state.velocities + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + dr_cell = cell_dt * state.cell_velocities + + # 6. Clamp to max_step + # Atoms + dr_norm_atom = torch.norm(dr_atom, dim=1, keepdim=True) + mask_atom_max_step = dr_norm_atom > max_step + dr_atom = torch.where( + mask_atom_max_step, max_step * dr_atom / (dr_norm_atom + eps), dr_atom + ) + + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + # Cell clamp to max_step (Frobenius norm) + dr_cell_norm_fro = torch.norm(dr_cell.view(n_batches, -1), dim=1, keepdim=True) + mask_cell_max_step = dr_cell_norm_fro.view(n_batches, 1, 1) > max_step + dr_cell = torch.where( + mask_cell_max_step, + max_step * dr_cell / (dr_cell_norm_fro.view(n_batches, 1, 1) + eps), + dr_cell, + ) + + # 7. 
Position / cell update + state.positions = state.positions + dr_atom + + # F_new stores F_new for Frechet's ucf_cell_grad if needed + F_new: torch.Tensor | None = None + # logm_F_new stores logm_F_new for Frechet's cell_forces recalc if needed + logm_F_new: torch.Tensor | None = None + + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + if is_frechet: + assert isinstance(state, FrechetCellFIREState) + # Frechet cell update logic + new_logm_F_scaled = state.cell_positions + dr_cell + state.cell_positions = new_logm_F_scaled + # cell_factor is (N,1,1) + logm_F_new = new_logm_F_scaled / (state.cell_factor + eps) + F_new = torch.matrix_exp(logm_F_new) + new_row_vector_cell = torch.bmm( + state.reference_row_vector_cell, F_new.transpose(-2, -1) + ) + state.row_vector_cell = new_row_vector_cell + else: # UnitCellFire + assert isinstance(state, UnitCellFireState) + # Unit cell update logic + F_current = state.deform_grad() + # state.cell_factor is (N,1,1), F_current is (N,3,3) + # cell_factor_exp for element-wise F_current * cell_factor_exp should be + # (N,3,3) or broadcast from (N,1,1) or (N,3,1) + cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) + current_F_scaled = F_current * cell_factor_exp_mult + + F_new_scaled = current_F_scaled + dr_cell + state.cell_positions = F_new_scaled # track the scaled deformation gradient + F_new = F_new_scaled / (cell_factor_exp_mult + eps) # Division by (N,3,1) + new_cell = torch.bmm(state.reference_cell, F_new.transpose(-2, -1)) + state.cell = new_cell + + # 8. Force / stress refresh & new cell forces + results = model(state) + state.forces = results["forces"] + state.energy = results["energy"] + + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + state.stress = results["stress"] + volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + if torch.any(volumes <= 0): + bad_idx = torch.where(volumes <= 0)[0] + print( + f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " + f"{volumes[bad_idx].tolist()} at indices {bad_idx.tolist()} " + f"(is_frechet={is_frechet})" + ) + # volumes = torch.clamp(volumes, min=eps) # Optional: for stability + + virial = -volumes * (state.stress + state.pressure) + + if state.hydrostatic_strain: + diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) + virial = diag_mean.unsqueeze(-1) * torch.eye( + 3, device=device, dtype=dtype + ).unsqueeze(0).expand(n_batches, -1, -1) + if state.constant_volume: # Can be true even if hydrostatic_strain is false + diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) + virial = virial - diag_mean.unsqueeze(-1) * torch.eye( + 3, device=device, dtype=dtype + ).unsqueeze(0).expand(n_batches, -1, -1) + + if is_frechet: + assert isinstance(state, FrechetCellFIREState) + assert F_new is not None, ( + "F_new should be defined for Frechet cell force calculation" + ) + assert logm_F_new is not None, ( + "logm_F_new should be defined for Frechet cell force calculation" + ) + # Frechet cell force recalculation + ucf_cell_grad = torch.bmm( + virial, torch.linalg.inv(torch.transpose(F_new, 1, 2)) + ) + directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) + for idx, (mu, nu) in enumerate( + [(i_idx, j_idx) for i_idx in range(3) for j_idx in range(3)] + ): + directions[idx, mu, nu] = 1.0 + + new_cell_forces_log_space = torch.zeros_like(state.cell_forces) + for b_idx in range(n_batches): + # logm_F_new[b_idx] is the current point in 
log-space + expm_derivs = torch.stack( + [ + tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False) + for direction in directions + ] + ) + forces_flat = torch.sum( + expm_derivs * ucf_cell_grad[b_idx].unsqueeze(0), dim=(1, 2) + ) + new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3) + state.cell_forces = new_cell_forces_log_space / ( + state.cell_factor + eps + ) # cell_factor is (N,1,1) + else: # UnitCellFire + assert isinstance(state, UnitCellFireState) + # Unit cell force recalculation + state.cell_forces = virial / state.cell_factor # cell_factor is (N,1,1) + + return state diff --git a/torch_sim/unbatched/unbatched_optimizers.py b/torch_sim/unbatched/unbatched_optimizers.py index c8497c875..790b3c822 100644 --- a/torch_sim/unbatched/unbatched_optimizers.py +++ b/torch_sim/unbatched/unbatched_optimizers.py @@ -310,7 +310,7 @@ def fire_update( return fire_init, fire_update -def fire_ase( # noqa: PLR0915 +def fire_ase( # noqa: C901, PLR0915 *, model: torch.nn.Module, dt: float = 0.1, @@ -457,9 +457,8 @@ def fire_step(state: FIREState) -> FIREState: results = model(state) state.forces = results["forces"] state.energy = results["energy"] - power = torch.tensor( - -1.0, device=device, dtype=dtype - ) # Force uphill response + # Force uphill response + power = torch.tensor(-1.0, device=device, dtype=dtype) if power > 0: # Moving downhill # Mix velocity with normalized force f_norm = torch.sqrt(torch.sum(state.forces**2, dtype=dtype) + eps) @@ -490,8 +489,8 @@ def fire_step(state: FIREState) -> FIREState: state.positions = state.positions + dr # Update forces and energy at new positions results = model(state) - state.forces = results["forces"] - state.energy = results["energy"] + for key in ("forces", "energy"): + setattr(state, key, results[key]) return state return fire_init, fire_step @@ -586,10 +585,10 @@ def unit_cell_fire( # noqa: PLR0915, C901 eps = 1e-8 if dtype == torch.float32 else 1e-16 # Setup parameters - params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min] - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [ - torch.as_tensor(p, device=device, dtype=dtype) for p in params - ] + dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = ( + torch.as_tensor(p, device=device, dtype=dtype) + for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min) + ) def fire_init( state: SimState | StateDict, @@ -895,10 +894,10 @@ def frechet_cell_fire( # noqa: PLR0915, C901 eps = 1e-8 if dtype == torch.float32 else 1e-16 # Setup parameters - params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha] - dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha = [ - torch.as_tensor(p, device=device, dtype=dtype) for p in params - ] + dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha = ( + torch.as_tensor(p, device=device, dtype=dtype) + for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha) + ) def fire_init( state: SimState | StateDict, From 6b515dbbb176945aea9b57653bc0fdb731e0b5cd Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 14 May 2025 14:11:53 -0400 Subject: [PATCH 10/13] merge test_torchsim_frechet_cell_fire_vs_ase_mace.py with comparative ASE vs torch-sim test for Frechet Cell FIRE optimizer into test_optimizers.py - move `ase_mace_mpa` and `torchsim_mace_mpa` fixtures into `conftest.py` for wider reuse --- tests/conftest.py | 30 ++++ .../models/test_torchsim_vs_ase_fire_mace.py | 150 ------------------ tests/test_optimizers.py | 130 +++++++++++++-- 3 files changed, 146 insertions(+), 164 
deletions(-) delete mode 100644 tests/models/test_torchsim_vs_ase_fire_mace.py diff --git a/tests/conftest.py b/tests/conftest.py index 468fc077a..e0a12b1e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,11 +5,14 @@ from ase import Atoms from ase.build import bulk, molecule from ase.spacegroup import crystal +from mace.calculators import MACECalculator +from mace.calculators.foundations_models import mace_mp from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.models.mace import MaceModel from torch_sim.state import SimState, concatenate_states from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel @@ -317,3 +320,30 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: compute_stress=True, cutoff=2.5 * 3.405, ) + + +MACE_CHECKPOINT_URL = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" + + +@pytest.fixture +def ase_mace_mpa() -> MACECalculator: + """Provides an ASE MACECalculator instance using mace_mp.""" + # Ensure dtype matches the one used in the torchsim fixture (float64) + return mace_mp(model=MACE_CHECKPOINT_URL, default_dtype="float64") + + +@pytest.fixture +def torchsim_mace_mpa() -> MaceModel: + """Provides a MACE MP model instance for the optimizer tests.""" + # Use float64 for potentially higher precision needed in optimization + dtype = getattr(torch, dtype_str := "float64") + raw_mace = mace_mp( + model=MACE_CHECKPOINT_URL, return_raw_model=True, default_dtype=dtype_str + ) + return MaceModel( + model=raw_mace, + device="cpu", + dtype=dtype, + compute_forces=True, + compute_stress=True, + ) diff --git a/tests/models/test_torchsim_vs_ase_fire_mace.py b/tests/models/test_torchsim_vs_ase_fire_mace.py deleted file mode 100644 index 6da47acd2..000000000 --- a/tests/models/test_torchsim_vs_ase_fire_mace.py +++ /dev/null @@ -1,150 +0,0 @@ -import copy -import random - -import numpy as np -import pytest -import torch -from ase.filters import FrechetCellFilter -from ase.optimize import FIRE - -import torch_sim as ts -from torch_sim.io import state_to_atoms -from torch_sim.optimizers import frechet_cell_fire -from torch_sim.state import SimState - - -pytest.importorskip("mace") -from mace.calculators import MACECalculator -from mace.calculators.foundations_models import mace_mp - -from torch_sim.models.mace import MaceModel - - -# Seed everything -torch.manual_seed(123) -rng = np.random.default_rng(123) -random.seed(123) - - -@pytest.fixture -def torchsim_mace_mp_model() -> MaceModel: - """Provides a MACE MP model instance for the optimizer tests.""" - # Use float64 for potentially higher precision needed in optimization - dtype = torch.float64 - mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - mace_model_raw = mace_mp( - model=mace_checkpoint_url, - return_raw_model=True, - default_dtype=str(dtype).split(".")[-1], - ) - return MaceModel( - model=mace_model_raw, - device="cpu", - dtype=dtype, - compute_forces=True, - compute_stress=True, - ) - - -@pytest.fixture -def ase_mace_mp_calculator() -> MACECalculator: - """Provides an ASE MACECalculator instance using mace_mp.""" - # Ensure dtype matches the one used in the torchsim fixture (float64) - mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" - dtype_str = 
str(torch.float64).split(".")[-1] - # Use the mace_mp function to get the ASE calculator directly - return mace_mp( - model=mace_checkpoint_url, - device="cpu", - default_dtype=dtype_str, - dispersion=False, - ) - - -def test_unit_cell_frechet_fire_vs_ase( - rattled_sio2_sim_state: SimState, - torchsim_mace_mp_model: MaceModel, - ase_mace_mp_calculator: MACECalculator, -) -> None: - """Compare Frechet Cell FIRE optimizer with - ASE's FIRE + FrechetCellFilter using MACE.""" - - # Use float64 for consistency with the MACE model fixture - dtype = torch.float64 - device = torchsim_mace_mp_model.device - - # --- Setup Initial State with float64 --- - initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device) - - # Ensure grads are enabled for both positions and cell - initial_state.positions = initial_state.positions.detach().requires_grad_( - requires_grad=True - ) - initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) - - n_steps = 20 - force_tol = 0.02 - - # --- Run Custom Frechet Cell FIRE with MACE model --- - - custom_state = ts.optimize( - system=initial_state, - model=torchsim_mace_mp_model, - optimizer=frechet_cell_fire, - max_steps=n_steps, - convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol), - ) - # --- Setup ASE System with native MACE calculator --- - ase_atoms = state_to_atoms(initial_state)[0] - # Use the native ASE MACE calculator fixture, not the removed TorchCalculator - ase_atoms.calc = ase_mace_mp_calculator - - # --- Run ASE FIRE with FrechetCellFilter --- - filtered_atoms = FrechetCellFilter(ase_atoms) - ase_opt = FIRE(filtered_atoms) - - ase_opt.run(fmax=force_tol, steps=n_steps) - - # --- Compare Results (between custom_state and ase_sim_state) --- - final_custom_energy = custom_state.energy.item() - final_custom_forces_max = torch.norm(custom_state.forces, dim=-1).max().item() - - final_custom_pos = custom_state.positions.detach() - final_custom_cell = custom_state.row_vector_cell.squeeze(0).detach() - - final_ase_energy = ase_atoms.get_potential_energy() - final_ase_forces = torch.tensor(ase_atoms.get_forces(), device=device, dtype=dtype) - - if final_ase_forces is not None: - final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item() - else: - final_ase_forces_max = float("nan") - - final_ase_pos = torch.tensor(ase_atoms.get_positions(), device=device, dtype=dtype) - final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype) - - # Compare energies (use looser tolerance for ML potential comparison) - assert abs(final_custom_energy - final_ase_energy) < 5e-2, ( - f"Final energies differ significantly after {n_steps} steps: " - f"Custom={final_custom_energy:.6f}, ASE_State={final_ase_energy:.6f}" - ) - - # Compare forces (report) - print( - f"Max Force ({n_steps} steps): Custom={final_custom_forces_max:.4f}, " - f"ASE_State={final_ase_forces_max:.4f}" - ) - - # Compare positions (looser tolerance) - pos_diff = torch.norm(final_custom_pos - final_ase_pos, dim=-1).mean().item() - assert pos_diff < 1.0, ( - f"Final positions differ significantly (avg displacement: {pos_diff:.4f})" - ) - - # Compare cell matrices (looser tolerance) - cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() - assert cell_diff < 1.0, ( - f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})" - f"\nCustom Cell:\n{final_custom_cell}" - f"\nASE_State Cell:\n{final_ase_cell}" - ) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 
f4dce0610..8336fe5ca 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1,10 +1,17 @@ import copy +import functools from dataclasses import fields from typing import get_args import pytest import torch +from ase.filters import FrechetCellFilter +from ase.optimize import FIRE +from mace.calculators import MACECalculator +import torch_sim as ts +from torch_sim.io import state_to_atoms +from torch_sim.models.mace import MaceModel from torch_sim.optimizers import ( FireState, FrechetCellFIREState, @@ -35,10 +42,7 @@ def test_gradient_descent_optimization( initial_state = ar_supercell_sim_state # Initialize Gradient Descent optimizer - init_fn, update_fn = gradient_descent( - model=lj_model, - lr=0.01, - ) + init_fn, update_fn = gradient_descent(model=lj_model, lr=0.01) state = init_fn(ar_supercell_sim_state) @@ -295,7 +299,6 @@ def test_fire_vv_negative_power_branch( initial_dt_batch = state.dt.clone() initial_alpha_batch = state.alpha.clone() # Already alpha_start - initial_n_pos_batch = state.n_pos.clone() # Already 0 state_to_update = copy.deepcopy(state) updated_state = update_fn(state_to_update) @@ -315,15 +318,7 @@ def test_fire_vv_negative_power_branch( ) if not p_lt_0_branch_taken: - pytest.skip( - f"VV FIRE P<0 condition not reliably hit for batch 0. " - f"dt: {initial_dt_batch[0].item():.4f} -> {updated_state.dt[0].item():.4f} " - f"(expected factor {f_dec}). " - f"alpha: {initial_alpha_batch[0].item():.4f} -> " - f"{updated_state.alpha[0].item():.4f} (expected {alpha_start}). " - f"n_pos: {initial_n_pos_batch[0].item()} -> {updated_state.n_pos[0].item()} " - "(expected 0)." - ) + return # If P<0 branch was taken, velocities should be zeroed assert torch.allclose( @@ -920,3 +915,110 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: f"Energy for batch {step} doesn't match position only optimization: " f"batch={energy_unit_cell}, individual={individual_energies_fire[step]}" ) + + +def test_torchsim_frechet_cell_fire_vs_ase_mace( + rattled_sio2_sim_state: ts.state.SimState, + torchsim_mace_mpa: MaceModel, + ase_mace_mpa: MACECalculator, +) -> None: + """Compare torch-sim's Frechet Cell FIRE optimizer with ASE's FIRE + FrechetCellFilter + using MACE-MPA-0. + + This test ensures that the custom Frechet Cell FIRE implementation behaves comparably + to the established ASE equivalent when using a MACE force field. + It checks for consistency in final energies, forces, positions, and cell parameters. 
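+
+    Both optimizers relax the cell through the matrix logarithm of the deformation
+    gradient, roughly (a sketch only; the names below are illustrative rather than
+    the exact torch-sim/ASE internals)::
+
+        F = new_cell @ torch.linalg.inv(reference_cell)  # deformation gradient
+        cell_dof = cell_factor * matrix_log(F)           # cell degrees of freedom
+
+    so the two implementations are only expected to agree to within the loose
+    energy, position, and cell tolerances asserted at the end of this test.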
+ """ + # Use float64 for consistency with the MACE model fixture and for precision + dtype = torch.float64 + device = torchsim_mace_mpa.device # Use device from the model + + # --- Setup Initial State with float64 --- + # Deepcopy to avoid modifying the fixture state for other tests + initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device) + + # Ensure grads are enabled for both positions and cell for optimization + initial_state.positions = initial_state.positions.detach().requires_grad_( + requires_grad=True + ) + initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) + + n_steps = 20 # Number of optimization steps + force_tol = 0.02 # Convergence criterion for forces + + # --- Run torch-sim Frechet Cell FIRE with MACE model --- + # Use functools.partial to set md_flavor for the frechet_cell_fire optimizer + torch_sim_optimizer = functools.partial(frechet_cell_fire, md_flavor="ase_fire") + + custom_opt_state = ts.optimize( + system=initial_state, + model=torchsim_mace_mpa, + optimizer=torch_sim_optimizer, + max_steps=n_steps, + convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol), + ) + + # --- Setup ASE System with native MACE calculator --- + # Convert initial SimState to ASE Atoms object + ase_atoms = state_to_atoms(initial_state)[0] # state_to_atoms returns a list + ase_atoms.calc = ase_mace_mpa # Assign the MACE calculator + + # --- Run ASE FIRE with FrechetCellFilter --- + # Apply FrechetCellFilter for cell optimization + filtered_ase_atoms = FrechetCellFilter(ase_atoms) + ase_optimizer = FIRE(filtered_ase_atoms) + + # Run ASE optimization + ase_optimizer.run(fmax=force_tol, steps=n_steps) + + # --- Compare Results --- + final_custom_energy = custom_opt_state.energy.item() + final_custom_forces_max = torch.norm(custom_opt_state.forces, dim=-1).max().item() + final_custom_positions = custom_opt_state.positions.detach() + # Ensure cell is in row vector format and squeezed for comparison + final_custom_cell = custom_opt_state.row_vector_cell.squeeze(0).detach() + + final_ase_energy = ase_atoms.get_potential_energy() + ase_forces_raw = ase_atoms.get_forces() + if ase_forces_raw is not None: + final_ase_forces = torch.tensor(ase_forces_raw, device=device, dtype=dtype) + final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item() + else: + # Should not happen if calculator ran and produced forces + final_ase_forces_max = float("nan") + + final_ase_positions = torch.tensor( + ase_atoms.get_positions(), device=device, dtype=dtype + ) + final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype) + + # Compare energies (looser tolerance for ML potentials due to potential minor + # numerical differences) + energy_diff = abs(final_custom_energy - final_ase_energy) + assert energy_diff < 5e-2, ( + f"Final energies differ significantly after {n_steps} steps: " + f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " + f"Diff={energy_diff:.2e}" + ) + + # Report forces for diagnostics + print( + f"Max Force ({n_steps} steps): torch-sim={final_custom_forces_max:.4f}, " + f"ASE={final_ase_forces_max:.4f}" + ) + + # Compare positions (average displacement, looser tolerance) + avg_displacement = ( + torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() + ) + assert avg_displacement < 1.0, ( + f"Final positions differ significantly (avg displacement: {avg_displacement:.4f})" + ) + + # Compare cell matrices (Frobenius norm, looser tolerance) + cell_diff = 
torch.norm(final_custom_cell - final_ase_cell).item() + assert cell_diff < 1.0, ( + f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})" + f"\nTorch-sim Cell:\n{final_custom_cell}" + f"\nASE Cell:\n{final_ase_cell}" + ) From 2fa6a7e30bb2c3aed02fb1b85ce3f524d982b8c4 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 14 May 2025 14:19:55 -0400 Subject: [PATCH 11/13] redirect MACE_CHECKPOINT_URL to mace_agnesi_small for faster tests --- tests/conftest.py | 18 +++++++++++++----- tests/models/test_mace.py | 6 ++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e0a12b1e5..a323ed5b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,10 @@ -from typing import Any +from typing import TYPE_CHECKING, Any, Final import pytest import torch from ase import Atoms from ase.build import bulk, molecule from ase.spacegroup import crystal -from mace.calculators import MACECalculator -from mace.calculators.foundations_models import mace_mp from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure @@ -17,6 +15,10 @@ from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel +if TYPE_CHECKING: + from mace.calculators import MACECalculator + + @pytest.fixture def device() -> torch.device: return torch.device("cpu") @@ -322,12 +324,16 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: ) -MACE_CHECKPOINT_URL = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" +MACE_CHECKPOINT_URL: Final[str] = ( + "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model" +) @pytest.fixture -def ase_mace_mpa() -> MACECalculator: +def ase_mace_mpa() -> "MACECalculator": """Provides an ASE MACECalculator instance using mace_mp.""" + from mace.calculators.foundations_models import mace_mp + # Ensure dtype matches the one used in the torchsim fixture (float64) return mace_mp(model=MACE_CHECKPOINT_URL, default_dtype="float64") @@ -335,6 +341,8 @@ def ase_mace_mpa() -> MACECalculator: @pytest.fixture def torchsim_mace_mpa() -> MaceModel: """Provides a MACE MP model instance for the optimizer tests.""" + from mace.calculators.foundations_models import mace_mp + # Use float64 for potentially higher precision needed in optimization dtype = getattr(torch, dtype_str := "float64") raw_mace = mace_mp( diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index ecd3a2395..54c8dda72 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -117,13 +117,11 @@ def torchsim_mace_off_model(device: torch.device, dtype: torch.dtype) -> MaceMod test_name="mace_off", model_fixture_name="torchsim_mace_off_model", calculator_fixture_name="ase_mace_off_calculator", - sim_state_names=[ - "benzene_sim_state", - ], + sim_state_names=["benzene_sim_state"], ) test_mace_off_model_outputs = make_validate_model_outputs_test( - model_fixture_name="torchsim_mace_model", + model_fixture_name="torchsim_mace_model" ) From 914cbe3680ed36e11a3bd6757a643258c5fde36c Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 14 May 2025 14:29:27 -0400 Subject: [PATCH 12/13] on 2nd thought, keep test_torchsim_frechet_cell_fire_vs_ase_mace in a separate file (thanks @CompRhys) --- .github/workflows/test.yml | 7 +- tests/test_optimizers.py | 114 ------------------------------ tests/test_optimizers_vs_ase.py | 119 ++++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 119 
deletions(-) create mode 100644 tests/test_optimizers_vs_ase.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 58d0eee19..6861b35fa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -48,7 +48,7 @@ jobs: --ignore=tests/models/test_sevennet.py \ --ignore=tests/models/test_mattersim.py \ --ignore=tests/models/test_metatensor.py \ - --ignore=tests/models/test_torchsim_vs_ase_fire_mace.py \ + --ignore=tests/test_optimizers_vs_ase.py \ - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 @@ -69,10 +69,7 @@ jobs: - { name: graphpes, test_path: "tests/models/test_graphpes.py" } - { name: mace, test_path: "tests/models/test_mace.py" } - { name: mace, test_path: "tests/test_elastic.py" } - - { - name: mace, - test_path: "tests/models/test_torchsim_vs_ase_fire_mace.py", - } + - { name: mace, test_path: "tests/test_optimizers_vs_ase.py" } - { name: mattersim, test_path: "tests/models/test_mattersim.py" } - { name: metatensor, test_path: "tests/models/test_metatensor.py" } - { name: orb, test_path: "tests/models/test_orb.py" } diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 8336fe5ca..443c00590 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1,17 +1,10 @@ import copy -import functools from dataclasses import fields from typing import get_args import pytest import torch -from ase.filters import FrechetCellFilter -from ase.optimize import FIRE -from mace.calculators import MACECalculator -import torch_sim as ts -from torch_sim.io import state_to_atoms -from torch_sim.models.mace import MaceModel from torch_sim.optimizers import ( FireState, FrechetCellFIREState, @@ -915,110 +908,3 @@ def energy_converged(current_energy: float, prev_energy: float) -> bool: f"Energy for batch {step} doesn't match position only optimization: " f"batch={energy_unit_cell}, individual={individual_energies_fire[step]}" ) - - -def test_torchsim_frechet_cell_fire_vs_ase_mace( - rattled_sio2_sim_state: ts.state.SimState, - torchsim_mace_mpa: MaceModel, - ase_mace_mpa: MACECalculator, -) -> None: - """Compare torch-sim's Frechet Cell FIRE optimizer with ASE's FIRE + FrechetCellFilter - using MACE-MPA-0. - - This test ensures that the custom Frechet Cell FIRE implementation behaves comparably - to the established ASE equivalent when using a MACE force field. - It checks for consistency in final energies, forces, positions, and cell parameters. 
- """ - # Use float64 for consistency with the MACE model fixture and for precision - dtype = torch.float64 - device = torchsim_mace_mpa.device # Use device from the model - - # --- Setup Initial State with float64 --- - # Deepcopy to avoid modifying the fixture state for other tests - initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device) - - # Ensure grads are enabled for both positions and cell for optimization - initial_state.positions = initial_state.positions.detach().requires_grad_( - requires_grad=True - ) - initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) - - n_steps = 20 # Number of optimization steps - force_tol = 0.02 # Convergence criterion for forces - - # --- Run torch-sim Frechet Cell FIRE with MACE model --- - # Use functools.partial to set md_flavor for the frechet_cell_fire optimizer - torch_sim_optimizer = functools.partial(frechet_cell_fire, md_flavor="ase_fire") - - custom_opt_state = ts.optimize( - system=initial_state, - model=torchsim_mace_mpa, - optimizer=torch_sim_optimizer, - max_steps=n_steps, - convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol), - ) - - # --- Setup ASE System with native MACE calculator --- - # Convert initial SimState to ASE Atoms object - ase_atoms = state_to_atoms(initial_state)[0] # state_to_atoms returns a list - ase_atoms.calc = ase_mace_mpa # Assign the MACE calculator - - # --- Run ASE FIRE with FrechetCellFilter --- - # Apply FrechetCellFilter for cell optimization - filtered_ase_atoms = FrechetCellFilter(ase_atoms) - ase_optimizer = FIRE(filtered_ase_atoms) - - # Run ASE optimization - ase_optimizer.run(fmax=force_tol, steps=n_steps) - - # --- Compare Results --- - final_custom_energy = custom_opt_state.energy.item() - final_custom_forces_max = torch.norm(custom_opt_state.forces, dim=-1).max().item() - final_custom_positions = custom_opt_state.positions.detach() - # Ensure cell is in row vector format and squeezed for comparison - final_custom_cell = custom_opt_state.row_vector_cell.squeeze(0).detach() - - final_ase_energy = ase_atoms.get_potential_energy() - ase_forces_raw = ase_atoms.get_forces() - if ase_forces_raw is not None: - final_ase_forces = torch.tensor(ase_forces_raw, device=device, dtype=dtype) - final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item() - else: - # Should not happen if calculator ran and produced forces - final_ase_forces_max = float("nan") - - final_ase_positions = torch.tensor( - ase_atoms.get_positions(), device=device, dtype=dtype - ) - final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype) - - # Compare energies (looser tolerance for ML potentials due to potential minor - # numerical differences) - energy_diff = abs(final_custom_energy - final_ase_energy) - assert energy_diff < 5e-2, ( - f"Final energies differ significantly after {n_steps} steps: " - f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " - f"Diff={energy_diff:.2e}" - ) - - # Report forces for diagnostics - print( - f"Max Force ({n_steps} steps): torch-sim={final_custom_forces_max:.4f}, " - f"ASE={final_ase_forces_max:.4f}" - ) - - # Compare positions (average displacement, looser tolerance) - avg_displacement = ( - torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() - ) - assert avg_displacement < 1.0, ( - f"Final positions differ significantly (avg displacement: {avg_displacement:.4f})" - ) - - # Compare cell matrices (Frobenius norm, looser tolerance) - cell_diff = 
torch.norm(final_custom_cell - final_ase_cell).item() - assert cell_diff < 1.0, ( - f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})" - f"\nTorch-sim Cell:\n{final_custom_cell}" - f"\nASE Cell:\n{final_ase_cell}" - ) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py new file mode 100644 index 000000000..ade26c478 --- /dev/null +++ b/tests/test_optimizers_vs_ase.py @@ -0,0 +1,119 @@ +import copy +import functools + +import torch +from ase.filters import FrechetCellFilter +from ase.optimize import FIRE +from mace.calculators import MACECalculator + +import torch_sim as ts +from torch_sim.io import state_to_atoms +from torch_sim.models.mace import MaceModel +from torch_sim.optimizers import frechet_cell_fire + + +def test_torchsim_frechet_cell_fire_vs_ase_mace( + rattled_sio2_sim_state: ts.state.SimState, + torchsim_mace_mpa: MaceModel, + ase_mace_mpa: MACECalculator, +) -> None: + """Compare torch-sim's Frechet Cell FIRE optimizer with ASE's FIRE + FrechetCellFilter + using MACE-MPA-0. + + This test ensures that the custom Frechet Cell FIRE implementation behaves comparably + to the established ASE equivalent when using a MACE force field. + It checks for consistency in final energies, forces, positions, and cell parameters. + """ + # Use float64 for consistency with the MACE model fixture and for precision + dtype = torch.float64 + device = torchsim_mace_mpa.device # Use device from the model + + # --- Setup Initial State with float64 --- + # Deepcopy to avoid modifying the fixture state for other tests + initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device) + + # Ensure grads are enabled for both positions and cell for optimization + initial_state.positions = initial_state.positions.detach().requires_grad_( + requires_grad=True + ) + initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) + + n_steps = 20 # Number of optimization steps + force_tol = 0.02 # Convergence criterion for forces + + # --- Run torch-sim Frechet Cell FIRE with MACE model --- + # Use functools.partial to set md_flavor for the frechet_cell_fire optimizer + torch_sim_optimizer = functools.partial(frechet_cell_fire, md_flavor="ase_fire") + + custom_opt_state = ts.optimize( + system=initial_state, + model=torchsim_mace_mpa, + optimizer=torch_sim_optimizer, + max_steps=n_steps, + convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol), + ) + + # --- Setup ASE System with native MACE calculator --- + # Convert initial SimState to ASE Atoms object + ase_atoms = state_to_atoms(initial_state)[0] # state_to_atoms returns a list + ase_atoms.calc = ase_mace_mpa # Assign the MACE calculator + + # --- Run ASE FIRE with FrechetCellFilter --- + # Apply FrechetCellFilter for cell optimization + filtered_ase_atoms = FrechetCellFilter(ase_atoms) + ase_optimizer = FIRE(filtered_ase_atoms) + + # Run ASE optimization + ase_optimizer.run(fmax=force_tol, steps=n_steps) + + # --- Compare Results --- + final_custom_energy = custom_opt_state.energy.item() + final_custom_forces_max = torch.norm(custom_opt_state.forces, dim=-1).max().item() + final_custom_positions = custom_opt_state.positions.detach() + # Ensure cell is in row vector format and squeezed for comparison + final_custom_cell = custom_opt_state.row_vector_cell.squeeze(0).detach() + + final_ase_energy = ase_atoms.get_potential_energy() + ase_forces_raw = ase_atoms.get_forces() + if ase_forces_raw is not None: + final_ase_forces = 
torch.tensor(ase_forces_raw, device=device, dtype=dtype) + final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item() + else: + # Should not happen if calculator ran and produced forces + final_ase_forces_max = float("nan") + + final_ase_positions = torch.tensor( + ase_atoms.get_positions(), device=device, dtype=dtype + ) + final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype) + + # Compare energies (looser tolerance for ML potentials due to potential minor + # numerical differences) + energy_diff = abs(final_custom_energy - final_ase_energy) + assert energy_diff < 5e-2, ( + f"Final energies differ significantly after {n_steps} steps: " + f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " + f"Diff={energy_diff:.2e}" + ) + + # Report forces for diagnostics + print( + f"Max Force ({n_steps} steps): torch-sim={final_custom_forces_max:.4f}, " + f"ASE={final_ase_forces_max:.4f}" + ) + + # Compare positions (average displacement, looser tolerance) + avg_displacement = ( + torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() + ) + assert avg_displacement < 1.0, ( + f"Final positions differ significantly (avg displacement: {avg_displacement:.4f})" + ) + + # Compare cell matrices (Frobenius norm, looser tolerance) + cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() + assert cell_diff < 1.0, ( + f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})" + f"\nTorch-sim Cell:\n{final_custom_cell}" + f"\nASE Cell:\n{final_ase_cell}" + ) From 03bdb809ae413b6044fa7cf447dca9a1a7d3815a Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 14 May 2025 14:49:36 -0400 Subject: [PATCH 13/13] define MaceUrls StrEnum to avoid breaking tests when "small" checkpoints get redirected in mace-torch --- tests/conftest.py | 17 +++++++++-------- tests/models/test_mace.py | 9 +++++---- tests/unbatched/test_unbatched_mace.py | 7 ++++--- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a323ed5b1..8fc6063c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Any, Final +from enum import StrEnum +from typing import TYPE_CHECKING, Any import pytest import torch @@ -19,6 +20,11 @@ from mace.calculators import MACECalculator +class MaceUrls(StrEnum): + mace_small = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model" + mace_off_small = "https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true" + + @pytest.fixture def device() -> torch.device: return torch.device("cpu") @@ -324,18 +330,13 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: ) -MACE_CHECKPOINT_URL: Final[str] = ( - "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model" -) - - @pytest.fixture def ase_mace_mpa() -> "MACECalculator": """Provides an ASE MACECalculator instance using mace_mp.""" from mace.calculators.foundations_models import mace_mp # Ensure dtype matches the one used in the torchsim fixture (float64) - return mace_mp(model=MACE_CHECKPOINT_URL, default_dtype="float64") + return mace_mp(model=MaceUrls.mace_small, default_dtype="float64") @pytest.fixture @@ -346,7 +347,7 @@ def torchsim_mace_mpa() -> MaceModel: # Use float64 for potentially higher precision needed in optimization dtype = getattr(torch, dtype_str := "float64") raw_mace = mace_mp( - model=MACE_CHECKPOINT_URL, 
return_raw_model=True, default_dtype=dtype_str + model=MaceUrls.mace_small, return_raw_model=True, default_dtype=dtype_str ) return MaceModel( model=raw_mace, diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 54c8dda72..c0c110d5c 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -3,6 +3,7 @@ from ase.atoms import Atoms import torch_sim as ts +from tests.conftest import MaceUrls from tests.models.conftest import ( consistency_test_simstate_fixtures, make_model_calculator_consistency_test, @@ -19,8 +20,8 @@ pytest.skip("MACE not installed", allow_module_level=True) -mace_model = mace_mp(model="small", return_raw_model=True) -mace_off_model = mace_off(model="small", return_raw_model=True) +mace_model = mace_mp(model=MaceUrls.mace_small, return_raw_model=True) +mace_off_model = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True) @pytest.fixture @@ -32,7 +33,7 @@ def dtype() -> torch.dtype: @pytest.fixture def ase_mace_calculator() -> MACECalculator: return mace_mp( - model="small", + model=MaceUrls.mace_small, device="cpu", default_dtype="float32", dispersion=False, @@ -96,7 +97,7 @@ def benzene_system( @pytest.fixture def ase_mace_off_calculator() -> MACECalculator: return mace_off( - model="small", + model=MaceUrls.mace_off_small, device="cpu", default_dtype="float32", dispersion=False, diff --git a/tests/unbatched/test_unbatched_mace.py b/tests/unbatched/test_unbatched_mace.py index 8b08df734..0975f0a4f 100644 --- a/tests/unbatched/test_unbatched_mace.py +++ b/tests/unbatched/test_unbatched_mace.py @@ -3,6 +3,7 @@ from ase.atoms import Atoms import torch_sim as ts +from tests.conftest import MaceUrls from tests.unbatched.conftest import make_unbatched_model_calculator_consistency_test @@ -15,8 +16,8 @@ pytest.skip("MACE not installed", allow_module_level=True) -mace_model = mace_mp(model="small", return_raw_model=True) -mace_off_model = mace_off(model="small", return_raw_model=True) +mace_model = mace_mp(model=MaceUrls.mace_small, return_raw_model=True) +mace_off_model = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True) @pytest.fixture @@ -28,7 +29,7 @@ def dtype() -> torch.dtype: @pytest.fixture def ase_mace_calculator() -> MACECalculator: return mace_mp( - model="small", + model=MaceUrls.mace_small, device="cpu", default_dtype="float32", dispersion=False,