diff --git a/examples/scripts/2_structural_optimization.py b/examples/scripts/2_structural_optimization.py index 423504994..288494ab1 100644 --- a/examples/scripts/2_structural_optimization.py +++ b/examples/scripts/2_structural_optimization.py @@ -386,6 +386,72 @@ print(f"Initial pressure: {initial_pressure} GPa") print(f"Final pressure: {final_pressure} GPa") +# ============================================================================ +# SECTION 7: Batched MACE L-BFGS +# ============================================================================ +print("\n" + "=" * 70) +print("SECTION 7: Batched MACE L-BFGS") +print("=" * 70) + +# Recreate structures with perturbations +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +results = model(state) +state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0) + +print("\nRunning L-BFGS:") +for step in range(N_steps): + if step % 20 == 0: + print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") + state = ts.lbfgs_step(state=state, model=model, max_history=100) + +print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") +print(f"Final energies: {[energy.item() for energy in state.energy]} eV") + + +# ============================================================================ +# SECTION 8: Batched MACE BFGS +# ============================================================================ +print("\n" + "=" * 70) +print("SECTION 8: Batched MACE BFGS") +print("=" * 70) + +# Recreate structures with perturbations +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +results = model(state) +state = ts.bfgs_init(state=state, model=model, alpha=70.0) + +print("\nRunning BFGS:") +for step in range(N_steps): + if step % 20 == 0: + print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") + state = ts.bfgs_step(state=state, model=model) + +print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") +print(f"Final energies: {[energy.item() for energy in state.energy]} eV") + + print("\n" + "=" * 70) print("Structural optimization examples completed!") print("=" * 70) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 7ec870d3c..c3e9aea23 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -605,3 +605,179 @@ def test_in_flight_max_iterations( # Verify iteration_count tracking for idx in range(len(states)): assert batcher.iteration_count[idx] == max_iterations + + +@pytest.mark.parametrize( + "num_steps_per_batch", + [ + 5, # At 5 steps, not every state will converge before the next batch. 
+ 10, # At 10 steps, all states will converge before the next batch + ], +) +def test_in_flight_with_bfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + num_steps_per_batch: int, +) -> None: + """Test InFlightAutoBatcher with BFGS optimizer.""" + si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_bfgs_state = ts.bfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5 + bfgs_states = [state.clone() for state in bfgs_states] + for state in bfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = InFlightAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=6000, + ) + batcher.load_states(bfgs_states) + + def convergence_fn(state: ts.BFGSState) -> torch.Tensor: + system_wise_max_force = torch.zeros( + state.n_systems, device=state.device, dtype=torch.float64 + ) + max_forces = state.forces.norm(dim=1) + system_wise_max_force = system_wise_max_force.scatter_reduce( + dim=0, index=state.system_idx, src=max_forces, reduce="amax" + ) + return system_wise_max_force < 5e-1 + + all_completed_states, convergence_tensor = [], None + while True: + state, completed_states = batcher.next_batch(state, convergence_tensor) + + all_completed_states.extend(completed_states) + if state is None: + break + + for _ in range(num_steps_per_batch): + state = ts.bfgs_step(state=state, model=lj_model) + convergence_tensor = convergence_fn(state) + + assert len(all_completed_states) == len(bfgs_states) + + +def test_binning_auto_batcher_with_bfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with BFGS optimizer.""" + si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_bfgs_state = ts.bfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5 + bfgs_states = [state.clone() for state in bfgs_states] + for state in bfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = BinningAutoBatcher( + model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000 + ) + batcher.load_states(bfgs_states) + + all_finished_states: list[ts.SimState] = [] + total_batches = 0 + for batch, _ in batcher: + total_batches += 1 # noqa: SIM113 + for _ in range(5): + batch = ts.bfgs_step(state=batch, model=lj_model) + all_finished_states.extend(batch.split()) + + assert len(all_finished_states) == len(bfgs_states) + + +@pytest.mark.parametrize( + "num_steps_per_batch", + [ + 5, # At 5 steps, not every state will converge before the next batch. 
+ 10, # At 10 steps, all states will converge before the next batch + ], +) +def test_in_flight_with_lbfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + num_steps_per_batch: int, +) -> None: + """Test InFlightAutoBatcher with L-BFGS optimizer.""" + si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_lbfgs_state = ts.lbfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5 + lbfgs_states = [state.clone() for state in lbfgs_states] + for state in lbfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = InFlightAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=6000, + ) + batcher.load_states(lbfgs_states) + + def convergence_fn(state: ts.LBFGSState) -> torch.Tensor: + system_wise_max_force = torch.zeros( + state.n_systems, device=state.device, dtype=torch.float64 + ) + max_forces = state.forces.norm(dim=1) + system_wise_max_force = system_wise_max_force.scatter_reduce( + dim=0, index=state.system_idx, src=max_forces, reduce="amax" + ) + return system_wise_max_force < 5e-1 + + all_completed_states, convergence_tensor = [], None + while True: + state, completed_states = batcher.next_batch(state, convergence_tensor) + + all_completed_states.extend(completed_states) + if state is None: + break + + for _ in range(num_steps_per_batch): + state = ts.lbfgs_step(state=state, model=lj_model) + convergence_tensor = convergence_fn(state) + + assert len(all_completed_states) == len(lbfgs_states) + + +def test_binning_auto_batcher_with_lbfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with L-BFGS optimizer.""" + si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_lbfgs_state = ts.lbfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5 + lbfgs_states = [state.clone() for state in lbfgs_states] + for state in lbfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = BinningAutoBatcher( + model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000 + ) + batcher.load_states(lbfgs_states) + + all_finished_states: list[ts.SimState] = [] + total_batches = 0 + for batch, _ in batcher: + total_batches += 1 # noqa: SIM113 + for _ in range(5): + batch = ts.lbfgs_step(state=batch, model=lj_model) + all_finished_states.extend(batch.split()) + + assert len(all_finished_states) == len(lbfgs_states) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 3bbfb44ed..a98833e48 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -392,6 +392,111 @@ def test_fix_com_fire_optimization( assert torch.allclose(final_com, initial_com, atol=1e-4) +@pytest.mark.parametrize("optimizer", ["bfgs", "lbfgs"]) +def test_fix_atoms_bfgs_lbfgs_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + optimizer: str, +) -> None: + """Test FixAtoms constraint in BFGS/LBFGS optimization.""" + # Create a fresh copy with random displacement + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + 
cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + indices = torch.tensor([0, 2], dtype=torch.long) + current_sim_state.constraints = [FixAtoms(atom_idx=indices)] + + # Initialize optimizer + if optimizer == "bfgs": + state = ts.bfgs_init(current_sim_state, lj_model) + step_fn = ts.bfgs_step + else: + state = ts.lbfgs_init(current_sim_state, lj_model) + step_fn = ts.lbfgs_step + + initial_position = state.positions[indices].clone() + + # Run optimization + energies = [1000, state.energy.item()] + max_steps = 500 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = step_fn(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + final_position = state.positions[indices] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + + +@pytest.mark.parametrize("optimizer", ["bfgs", "lbfgs"]) +def test_fix_com_bfgs_lbfgs_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + optimizer: str, +) -> None: + """Test FixCom constraint in BFGS/LBFGS optimization.""" + # Create a fresh copy with random displacement + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + current_sim_state.constraints = [FixCom([0])] + + # Initialize optimizer + if optimizer == "bfgs": + state = ts.bfgs_init(current_sim_state, lj_model) + step_fn = ts.bfgs_step + else: + state = ts.lbfgs_init(current_sim_state, lj_model) + step_fn = ts.lbfgs_step + + initial_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + # Run optimization + energies = [1000, state.energy.item()] + max_steps = 500 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = step_fn(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + final_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + assert torch.allclose(final_com, initial_com, atol=1e-4) + + def test_fix_atoms_validation() -> None: """Test FixAtoms construction and validation.""" # Boolean mask conversion @@ -587,6 +692,41 @@ def test_cell_optimization_with_constraints( assert len(state.constraints) > 0 +@pytest.mark.parametrize( + ("cell_filter", "optimizer"), + [ + (ts.CellFilter.unit, "bfgs"), + (ts.CellFilter.frechet, "bfgs"), + (ts.CellFilter.unit, "lbfgs"), + (ts.CellFilter.frechet, "lbfgs"), + ], +) +def test_cell_optimization_with_constraints_bfgs_lbfgs( + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + cell_filter: str, + optimizer: str, +) -> None: + """Test cell filters work with constraints for BFGS/LBFGS.""" + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.05 + ) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1])] + + if optimizer == "bfgs": + state = 
ts.bfgs_init(ar_supercell_sim_state, lj_model, cell_filter=cell_filter) + step_fn = ts.bfgs_step + else: + state = ts.lbfgs_init(ar_supercell_sim_state, lj_model, cell_filter=cell_filter) + step_fn = ts.lbfgs_step + + for _ in range(50): + state = step_fn(state, lj_model) + if state.forces.abs().max() < 0.05: + break + assert len(state.constraints) > 0 + + def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: """Test system-specific constraints in batched states.""" s1, s2 = ar_double_sim_state.split() diff --git a/tests/test_optimizer_states.py b/tests/test_optimizer_states.py index eee408756..2052974e4 100644 --- a/tests/test_optimizer_states.py +++ b/tests/test_optimizer_states.py @@ -3,7 +3,7 @@ import pytest import torch -from torch_sim.optimizers.state import FireState, OptimState +from torch_sim.optimizers.state import BFGSState, FireState, LBFGSState, OptimState from torch_sim.state import SimState @@ -57,3 +57,53 @@ def test_fire_state_custom_values(sim_state: SimState, optim_data: dict) -> None assert torch.equal(state.dt, fire_data["dt"]) assert torch.equal(state.alpha, fire_data["alpha"]) assert torch.equal(state.n_pos, fire_data["n_pos"]) + + +def test_bfgs_state_custom_values(sim_state: SimState, optim_data: dict) -> None: + """Test BFGSState with custom values.""" + bfgs_data = { + "hessian": torch.eye(6, dtype=torch.float64).unsqueeze(0), # [1, 6, 6] + "prev_forces": optim_data["forces"].clone(), + "prev_positions": sim_state.positions.clone(), + "alpha": torch.tensor([70.0], dtype=torch.float64), + "max_step": torch.tensor([0.2], dtype=torch.float64), + "n_iter": torch.tensor([0], dtype=torch.int32), + "atom_idx_in_system": torch.arange(2, dtype=torch.int64), + "max_atoms": torch.tensor([2], dtype=torch.int64), + } + + state = BFGSState(**sim_state.attributes, **optim_data, **bfgs_data) + + assert torch.equal(state.hessian, bfgs_data["hessian"]) + assert torch.equal(state.prev_forces, bfgs_data["prev_forces"]) + assert torch.equal(state.prev_positions, bfgs_data["prev_positions"]) + assert torch.equal(state.alpha, bfgs_data["alpha"]) + assert torch.equal(state.max_step, bfgs_data["max_step"]) + assert torch.equal(state.n_iter, bfgs_data["n_iter"]) + assert torch.equal(state.atom_idx_in_system, bfgs_data["atom_idx_in_system"]) + assert torch.equal(state.max_atoms, bfgs_data["max_atoms"]) + + +def test_lbfgs_state_custom_values(sim_state: SimState, optim_data: dict) -> None: + """Test LBFGSState with custom values.""" + lbfgs_data = { + "prev_forces": optim_data["forces"].clone(), + "prev_positions": sim_state.positions.clone(), + "s_history": torch.zeros((1, 0, 2, 3), dtype=torch.float64), # [S, H, M, 3] + "y_history": torch.zeros((1, 0, 2, 3), dtype=torch.float64), + "step_size": torch.tensor([1.0], dtype=torch.float64), + "alpha": torch.tensor([70.0], dtype=torch.float64), + "n_iter": torch.tensor([0], dtype=torch.int32), + "max_atoms": torch.tensor([2], dtype=torch.int64), + } + + state = LBFGSState(**sim_state.attributes, **optim_data, **lbfgs_data) + + assert torch.equal(state.prev_forces, lbfgs_data["prev_forces"]) + assert torch.equal(state.prev_positions, lbfgs_data["prev_positions"]) + assert torch.equal(state.s_history, lbfgs_data["s_history"]) + assert torch.equal(state.y_history, lbfgs_data["y_history"]) + assert torch.equal(state.step_size, lbfgs_data["step_size"]) + assert torch.equal(state.alpha, lbfgs_data["alpha"]) + assert torch.equal(state.n_iter, lbfgs_data["n_iter"]) + assert torch.equal(state.max_atoms, lbfgs_data["max_atoms"]) diff 
--git a/tests/test_optimizers.py b/tests/test_optimizers.py index 0acb9835a..f5071d1f9 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -9,7 +9,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.optimizers import FireFlavor, FireState, OptimState +from torch_sim.optimizers import BFGSState, FireFlavor, FireState, LBFGSState, OptimState from torch_sim.state import SimState @@ -161,11 +161,409 @@ def test_fire_optimization( ) +def test_bfgs_optimization( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test that the BFGS optimizer actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + + # Initialize BFGS optimizer + state = ts.bfgs_init(current_sim_state, lj_model) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.bfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, f"BFGS optimization did not converge in {max_steps=}" + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"BFGS optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert max_force < 0.3, f"Forces should be small after optimization, got {max_force=}" + + assert not torch.allclose(state.positions, initial_state_positions), ( + "BFGS positions should have changed after optimization." 
+ ) + + +def test_lbfgs_optimization( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test that the L-BFGS optimizer actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + + # Initialize L-BFGS optimizer + state = ts.lbfgs_init(current_sim_state, lj_model) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.lbfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, ( + f"L-BFGS optimization did not converge in {max_steps=}" + ) + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"L-BFGS optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert max_force < 0.3, f"Forces should be small after optimization, got {max_force=}" + + assert not torch.allclose(state.positions, initial_state_positions), ( + "L-BFGS positions should have changed after optimization." + ) + + +@pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) +def test_bfgs_cell_optimization( + ar_supercell_sim_state: SimState, + lj_model: ModelInterface, + cell_filter: ts.CellFilter, +) -> None: + """Test that BFGS with cell filter actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + current_cell = ( + ar_supercell_sim_state.cell.clone() + + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=current_cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + initial_state_cell = current_sim_state.cell.clone() + + # Initialize BFGS optimizer with cell filter + state = ts.bfgs_init( + state=current_sim_state, + model=lj_model, + cell_filter=cell_filter, + ) + + # Run optimization + energies = [1000.0, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.bfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, ( + f"BFGS {cell_filter.name} optimization did not converge in {max_steps=}" + ) + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"BFGS {cell_filter.name} optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + pressure = 
torch.trace(state.stress.squeeze(0)) / 3.0 + + assert torch.abs(pressure) < 0.05, ( + f"Pressure should be small after {cell_filter.name} optimization, got {pressure=}" + ) + assert max_force < 0.3, ( + f"Forces should be small after {cell_filter.name} optimization, got {max_force=}" + ) + + assert not torch.allclose(state.positions, initial_state_positions, atol=1e-5), ( + f"BFGS {cell_filter.name} positions should have changed after optimization." + ) + assert not torch.allclose(state.cell, initial_state_cell, atol=1e-5), ( + f"BFGS {cell_filter.name} cell should have changed after optimization." + ) + + +def test_unit_cell_bfgs_multi_batch( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test BFGS optimization with multiple batches.""" + generator = torch.Generator(device=ar_supercell_sim_state.device) + + ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) + ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) + + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): + generator.manual_seed(43) + state.positions += ( + torch.randn( + state.positions.shape, + device=state.device, + generator=generator, + ) + * 0.1 + ) + + multi_state = ts.concatenate_states( + [ar_supercell_sim_state_1, ar_supercell_sim_state_2], + device=ar_supercell_sim_state.device, + ) + + # Initialize BFGS optimizer with unit cell filter + state = ts.bfgs_init( + state=multi_state, model=lj_model, cell_filter=ts.CellFilter.unit + ) + initial_state = copy.deepcopy(state) + + # Run optimization + prev_energy = torch.ones(2, device=state.device, dtype=state.energy.dtype) * 1000 + current_energy = initial_state.energy + step = 0 + while not torch.allclose(current_energy, prev_energy, atol=1e-9): + prev_energy = current_energy + state = ts.bfgs_step(state=state, model=lj_model) + current_energy = state.energy + + step += 1 + if step > 500: + raise ValueError("BFGS optimization did not converge") + + # Check that we actually optimized + assert step > 5 + + # Check that energy decreased for both batches + assert torch.all(state.energy < initial_state.energy), ( + "BFGS optimization should reduce energy for all batches" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert torch.all(max_force < 0.2), ( + f"Forces should be small after optimization, got {max_force=}" + ) + + n_ar_atoms = ar_supercell_sim_state.n_atoms + assert not torch.allclose( + state.positions[:n_ar_atoms], multi_state.positions[:n_ar_atoms] + ) + assert not torch.allclose( + state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:] + ) + + # We are evolving identical systems + assert torch.allclose(current_energy[0], current_energy[1]) + + +@pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) +def test_lbfgs_cell_optimization( + ar_supercell_sim_state: SimState, + lj_model: ModelInterface, + cell_filter: ts.CellFilter, +) -> None: + """Test that L-BFGS with cell filter actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + current_cell = ( + ar_supercell_sim_state.cell.clone() + + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=current_cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + 
system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + initial_state_cell = current_sim_state.cell.clone() + + # Initialize L-BFGS optimizer with cell filter + state = ts.lbfgs_init( + state=current_sim_state, + model=lj_model, + cell_filter=cell_filter, + ) + + # Run optimization + energies = [1000.0, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.lbfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, ( + f"L-BFGS {cell_filter.name} optimization did not converge in {max_steps=}" + ) + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"L-BFGS {cell_filter.name} optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + pressure = torch.trace(state.stress.squeeze(0)) / 3.0 + + assert torch.abs(pressure) < 0.05, ( + f"Pressure should be small after {cell_filter.name} optimization, got {pressure=}" + ) + assert max_force < 0.3, ( + f"Forces should be small after {cell_filter.name} optimization, got {max_force=}" + ) + + assert not torch.allclose(state.positions, initial_state_positions, atol=1e-5), ( + f"L-BFGS {cell_filter.name} positions should have changed after optimization." + ) + assert not torch.allclose(state.cell, initial_state_cell, atol=1e-5), ( + f"L-BFGS {cell_filter.name} cell should have changed after optimization." + ) + + +def test_unit_cell_lbfgs_multi_batch( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test L-BFGS optimization with multiple batches.""" + generator = torch.Generator(device=ar_supercell_sim_state.device) + + ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) + ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) + + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): + generator.manual_seed(43) + state.positions += ( + torch.randn( + state.positions.shape, + device=state.device, + generator=generator, + ) + * 0.1 + ) + + multi_state = ts.concatenate_states( + [ar_supercell_sim_state_1, ar_supercell_sim_state_2], + device=ar_supercell_sim_state.device, + ) + + # Initialize L-BFGS optimizer with unit cell filter + state = ts.lbfgs_init( + state=multi_state, model=lj_model, cell_filter=ts.CellFilter.unit + ) + initial_state = copy.deepcopy(state) + + # Run optimization + prev_energy = torch.ones(2, device=state.device, dtype=state.energy.dtype) * 1000 + current_energy = initial_state.energy + step = 0 + while not torch.allclose(current_energy, prev_energy, atol=1e-9): + prev_energy = current_energy + state = ts.lbfgs_step(state=state, model=lj_model) + current_energy = state.energy + + step += 1 + if step > 500: + raise ValueError("L-BFGS optimization did not converge") + + # Check that we actually optimized + assert step > 5 + + # Check that energy decreased for both batches + assert torch.all(state.energy < initial_state.energy), ( + "L-BFGS optimization should reduce energy for all batches" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert torch.all(max_force < 0.2), ( + f"Forces should be small after optimization, got {max_force=}" + ) + + n_ar_atoms = ar_supercell_sim_state.n_atoms + assert not torch.allclose( + 
state.positions[:n_ar_atoms], multi_state.positions[:n_ar_atoms] + ) + assert not torch.allclose( + state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:] + ) + + # We are evolving identical systems + assert torch.allclose(current_energy[0], current_energy[1]) + + @pytest.mark.parametrize( ("optimizer_fn", "expected_state_type"), [ (ts.Optimizer.fire, FireState), (ts.Optimizer.gradient_descent, OptimState), + (ts.Optimizer.bfgs, BFGSState), + (ts.Optimizer.lbfgs, LBFGSState), ], ) def test_simple_optimizer_init_with_dict( @@ -410,6 +808,10 @@ def test_unit_cell_fire_optimization( 50.0, ), (ts.Optimizer.fire, ts.CellFilter.frechet, ts.CellFireState, 75.0), + (ts.Optimizer.bfgs, ts.CellFilter.unit, ts.CellBFGSState, 100), + (ts.Optimizer.bfgs, ts.CellFilter.frechet, ts.CellBFGSState, 75.0), + (ts.Optimizer.lbfgs, ts.CellFilter.unit, ts.CellLBFGSState, 100), + (ts.Optimizer.lbfgs, ts.CellFilter.frechet, ts.CellLBFGSState, 75.0), ], ) def test_cell_optimizer_init_with_dict_and_cell_factor( @@ -462,6 +864,10 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( ts.CellFilter.frechet, ts.CellOptimState, ), + (ts.Optimizer.bfgs, ts.CellFilter.unit, ts.CellBFGSState), + (ts.Optimizer.bfgs, ts.CellFilter.frechet, ts.CellBFGSState), + (ts.Optimizer.lbfgs, ts.CellFilter.unit, ts.CellLBFGSState), + (ts.Optimizer.lbfgs, ts.CellFilter.frechet, ts.CellLBFGSState), ], ) def test_cell_optimizer_init_cell_factor_none( @@ -918,6 +1324,10 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> (ts.Optimizer.gradient_descent, None), (ts.Optimizer.fire, ts.CellFilter.unit), (ts.Optimizer.gradient_descent, ts.CellFilter.frechet), + (ts.Optimizer.bfgs, None), + (ts.Optimizer.lbfgs, None), + (ts.Optimizer.bfgs, ts.CellFilter.unit), + (ts.Optimizer.lbfgs, ts.CellFilter.frechet), ], ) def test_optimizer_preserves_charge_spin( @@ -956,8 +1366,10 @@ def test_optimizer_preserves_charge_spin( for _ in range(3): if optimizer_fn == ts.Optimizer.fire: opt_state = step_fn(state=opt_state, model=lj_model, dt_max=0.3) - else: + elif optimizer_fn == ts.Optimizer.gradient_descent: opt_state = step_fn(state=opt_state, model=lj_model, pos_lr=0.01, cell_lr=0.1) + else: + opt_state = step_fn(state=opt_state, model=lj_model) assert torch.allclose(opt_state.charge, original_charge) assert torch.allclose(opt_state.spin, original_spin) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index 812375684..aa509138a 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -4,7 +4,9 @@ import pytest import torch from ase.filters import FrechetCellFilter, UnitCellFilter +from ase.optimize import BFGS as ASE_BFGS from ase.optimize import FIRE +from ase.optimize import LBFGS as ASE_LBFGS from pymatgen.analysis.structure_matcher import StructureMatcher import torch_sim as ts @@ -316,3 +318,338 @@ def test_optimizer_vs_ase_parametrized( tolerances=tolerances, test_id_prefix=test_id_prefix, ) + + +# TODO (AG): Can we merge these tests with the FIRE tests? 
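+# Both parametrized tests below follow the same pattern: run torch-sim and ASE
+# side by side, pausing at the same step counts ("checkpoints") and comparing
+# energies, forces and relaxed structures at each checkpoint.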
+ + +@pytest.mark.parametrize( + ( + "sim_state_fixture_name", + "cell_filter", + "ase_filter_cls", + "checkpoints", + "force_tol", + "tolerances", + "test_id_prefix", + ), + [ + ( + "rattled_sio2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS SiO2 (Frechet)", + ), + ( + "osn2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS OsN2 (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "BFGS Triclinic Al (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "BFGS Triclinic Al (UnitCell)", + ), + ( + "rattled_sio2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS SiO2 (UnitCell)", + ), + ( + "osn2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS OsN2 (UnitCell)", + ), + ], +) +def test_bfgs_vs_ase_parametrized( + sim_state_fixture_name: str, + cell_filter: ts.CellFilter, + ase_filter_cls: type, + checkpoints: list[int], + force_tol: float, + tolerances: dict[str, float], + test_id_prefix: str, + ts_mace_mpa: MaceModel, + ase_mace_mpa: "MACECalculator", + request: pytest.FixtureRequest, +) -> None: + """Compare torch-sim BFGS with ASE BFGS at multiple checkpoints.""" + pytest.importorskip("mace") + device = ts_mace_mpa.device + + initial_sim_state = request.getfixturevalue(sim_state_fixture_name) + state = initial_sim_state.clone() + + ase_atoms = ts.io.state_to_atoms( + initial_sim_state.clone().to(dtype=DTYPE, device=device) + )[0] + ase_atoms.calc = ase_mace_mpa + filtered_ase_atoms = ase_filter_cls(ase_atoms) + ase_optimizer = ASE_BFGS(filtered_ase_atoms, logfile=None, alpha=70.0) + + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=True + ) + + # Compare initial state + results = ts_mace_mpa(state) + ts_initial = state.clone() + ts_initial.forces = results["forces"] + ts_initial.energy = results["energy"] + ase_mace_mpa.calculate(ase_atoms) + _compare_ase_and_ts_states( + ts_initial, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Initial)" + ) + + last_step = 0 + for checkpoint in checkpoints: + steps = checkpoint - last_step + if steps > 0: + state = ts.optimize( + system=state, + model=ts_mace_mpa, + optimizer=ts.Optimizer.bfgs, + max_steps=steps, + convergence_fn=convergence_fn, + steps_between_swaps=1, + init_kwargs=dict(cell_filter=cell_filter), + ) + ase_optimizer.run(fmax=force_tol, steps=steps) + + _compare_ase_and_ts_states( + state, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Step {checkpoint})" + ) + last_step = checkpoint + + +# TODO (AG): Can we merge these tests with the FIRE tests? 
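+# Note: the L-BFGS test below passes explicit init kwargs (alpha=70.0,
+# step_size=1.0) plus max_step=0.2, matching the alpha=70.0, damping=1.0
+# settings given to ASE_LBFGS so the two optimizers take comparable steps.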
+ + +@pytest.mark.parametrize( + ( + "sim_state_fixture_name", + "cell_filter", + "ase_filter_cls", + "checkpoints", + "force_tol", + "tolerances", + "test_id_prefix", + ), + [ + ( + "rattled_sio2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS SiO2 (Frechet)", + ), + ( + "osn2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS OsN2 (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "LBFGS Triclinic Al (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "LBFGS Triclinic Al (UnitCell)", + ), + ( + "rattled_sio2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS SiO2 (UnitCell)", + ), + ( + "osn2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS OsN2 (UnitCell)", + ), + ], +) +def test_lbfgs_vs_ase_parametrized( + sim_state_fixture_name: str, + cell_filter: ts.CellFilter, + ase_filter_cls: type, + checkpoints: list[int], + force_tol: float, + tolerances: dict[str, float], + test_id_prefix: str, + ts_mace_mpa: MaceModel, + ase_mace_mpa: "MACECalculator", + request: pytest.FixtureRequest, +) -> None: + """Compare torch-sim L-BFGS with ASE LBFGS at multiple checkpoints.""" + pytest.importorskip("mace") + device = ts_mace_mpa.device + + initial_sim_state = request.getfixturevalue(sim_state_fixture_name) + state = initial_sim_state.clone() + + ase_atoms = ts.io.state_to_atoms( + initial_sim_state.clone().to(dtype=DTYPE, device=device) + )[0] + ase_atoms.calc = ase_mace_mpa + filtered_ase_atoms = ase_filter_cls(ase_atoms) + ase_optimizer = ASE_LBFGS(filtered_ase_atoms, logfile=None, alpha=70.0, damping=1.0) + + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=True + ) + + # Compare initial state + results = ts_mace_mpa(state) + ts_initial = state.clone() + ts_initial.forces = results["forces"] + ts_initial.energy = results["energy"] + ase_mace_mpa.calculate(ase_atoms) + _compare_ase_and_ts_states( + ts_initial, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Initial)" + ) + + last_step = 0 + for checkpoint in checkpoints: + steps = checkpoint - last_step + if steps > 0: + state = ts.optimize( + system=state, + model=ts_mace_mpa, + optimizer=ts.Optimizer.lbfgs, + max_steps=steps, + convergence_fn=convergence_fn, + steps_between_swaps=1, + init_kwargs=dict(cell_filter=cell_filter, alpha=70.0, step_size=1.0), + max_step=0.2, + ) + ase_optimizer.run(fmax=force_tol, steps=steps) + + _compare_ase_and_ts_states( + state, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Step {checkpoint})" + ) + last_step = checkpoint diff --git a/torch_sim/__init__.py 
b/torch_sim/__init__.py index 8af199061..b56973170 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -53,18 +53,26 @@ from torch_sim.monte_carlo import SwapMCState, swap_mc_init, swap_mc_step from torch_sim.optimizers import ( OPTIM_REGISTRY, + BFGSState, FireState, + LBFGSState, Optimizer, OptimState, + bfgs_init, + bfgs_step, fire_init, fire_step, gradient_descent_init, gradient_descent_step, + lbfgs_init, + lbfgs_step, ) from torch_sim.optimizers.cell_filters import ( CELL_FILTER_REGISTRY, + CellBFGSState, CellFilter, CellFireState, + CellLBFGSState, CellOptimState, get_cell_filter, ) diff --git a/torch_sim/optimizers/__init__.py b/torch_sim/optimizers/__init__.py index 850cfcac5..7223e84cb 100644 --- a/torch_sim/optimizers/__init__.py +++ b/torch_sim/optimizers/__init__.py @@ -10,13 +10,25 @@ from enum import StrEnum from typing import Any, Final, Literal, get_args -from torch_sim.optimizers.cell_filters import CellFireState, CellOptimState # noqa: F401 +from torch_sim.optimizers.bfgs import bfgs_init, bfgs_step +from torch_sim.optimizers.cell_filters import ( # noqa: F401 + CellBFGSState, + CellFireState, + CellLBFGSState, + CellOptimState, +) from torch_sim.optimizers.fire import fire_init, fire_step from torch_sim.optimizers.gradient_descent import ( gradient_descent_init, gradient_descent_step, ) -from torch_sim.optimizers.state import FireState, OptimState # noqa: F401 +from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step +from torch_sim.optimizers.state import ( # noqa: F401 + BFGSState, + FireState, + LBFGSState, + OptimState, +) FireFlavor = Literal["vv_fire", "ase_fire"] @@ -28,9 +40,13 @@ class Optimizer(StrEnum): gradient_descent = "gradient_descent" fire = "fire" + lbfgs = "lbfgs" + bfgs = "bfgs" OPTIM_REGISTRY: Final[dict[Optimizer, tuple[Callable[..., Any], Callable[..., Any]]]] = { Optimizer.gradient_descent: (gradient_descent_init, gradient_descent_step), Optimizer.fire: (fire_init, fire_step), + Optimizer.lbfgs: (lbfgs_init, lbfgs_step), + Optimizer.bfgs: (bfgs_init, bfgs_step), } diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py new file mode 100644 index 000000000..ccf84c4d4 --- /dev/null +++ b/torch_sim/optimizers/bfgs.py @@ -0,0 +1,552 @@ +"""BFGS (Broyden-Fletcher-Goldfarb-Shanno) optimizer implementation. + +This module provides a batched BFGS optimizer that maintains the full Hessian +matrix for each system. This is suitable for systems with a small to moderate +number of atoms, where the $O(N^2)$ memory cost is acceptable. + +The implementation handles batches of systems with different numbers of atoms +by padding vectors to the maximum number of atoms in the batch. The Hessian +matrices are similarly padded to shape (n_systems, 3*max_atoms, 3*max_atoms). + +Note: When cell_filter is active, forces are transformed using the deformation gradient +to work in the same scaled coordinate space as ASE's UnitCellFilter/FrechetCellFilter. +The prev_forces and prev_positions are stored in the scaled/fractional space to match +ASE's behavior. 
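+
+Example:
+    A minimal usage sketch, assuming an existing ``SimState`` ``state`` and a
+    ``ModelInterface`` ``model``::
+
+        import torch_sim as ts
+
+        opt_state = ts.bfgs_init(state, model, alpha=70.0, max_step=0.2)
+        for _ in range(100):
+            opt_state = ts.bfgs_step(state=opt_state, model=model)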
+""" + +from typing import TYPE_CHECKING, Any + +import torch + +import torch_sim as ts +from torch_sim.optimizers import cell_filters +from torch_sim.optimizers.cell_filters import frechet_cell_filter_init +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import BFGSState, CellBFGSState + from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs + + +def _get_atom_indices_per_system( + system_idx: torch.Tensor, n_systems: int +) -> torch.Tensor: + """Compute the index of each atom within its system. + + Assumes atoms are grouped contiguously by system. + + Args: + system_idx: Tensor of system indices [n_atoms] + n_systems: Number of systems + + Returns: + Tensor of [0, 1, 2, ..., 0, 1, ...] [n_atoms] + """ + # We assume contiguous atoms for each system, which is standard in SimState + counts = torch.bincount(system_idx, minlength=n_systems) + # Create ranges [0...n-1] for each system and concatenate + indices = [torch.arange(c, device=system_idx.device) for c in counts] + return torch.cat(indices) + + +def _pad_to_dense( + flat_tensor: torch.Tensor, + system_idx: torch.Tensor, + atom_idx_in_system: torch.Tensor, + n_systems: int, + max_atoms: int, +) -> torch.Tensor: + """Convert a packed tensor to a padded dense tensor. + + Args: + flat_tensor: [n_atoms, D] + system_idx: [n_atoms] + atom_idx_in_system: [n_atoms] + n_systems: int + max_atoms: int + + Returns: + dense_tensor: [n_systems, max_atoms, D] + """ + D = flat_tensor.shape[1] + dense = torch.zeros( + (n_systems, max_atoms, D), dtype=flat_tensor.dtype, device=flat_tensor.device + ) + dense[system_idx, atom_idx_in_system] = flat_tensor + return dense + + +def bfgs_init( + state: SimState | StateDict, + model: "ModelInterface", + *, + max_step: float = 0.2, + alpha: float = 70.0, + cell_filter: "CellFilter | CellFilterFuncs | None" = None, + **filter_kwargs: Any, +) -> "BFGSState | CellBFGSState": + """Create an initial BFGSState. + + Initializes the Hessian as Identity matrix * alpha. 
+ + Shape notation: + N = total atoms across all systems (n_atoms) + S = number of systems (n_systems) + M = max atoms per system (max_atoms) + D = 3*M (position DOFs) + D_ext = 3*M + 9 (extended DOFs with cell) + + Args: + state: Input state + model: Model + max_step: Maximum step size (Angstrom) + alpha: Initial Hessian stiffness (eV/A^2) + cell_filter: Filter for cell optimization (None for position-only optimization) + **filter_kwargs: Additional arguments passed to cell filter initialization + + Returns: + BFGSState or CellBFGSState if cell_filter is provided + """ + from torch_sim.optimizers import BFGSState, CellBFGSState + + tensor_args = {"device": model.device, "dtype": model.dtype} + + if not isinstance(state, SimState): + state = SimState(**state) + + n_systems = state.n_systems # S + + counts = state.n_atoms_per_system # [S] + global_max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 # M + # Per-system max_atoms for padding/unpadding support + max_atoms = counts.clone() # [S] - each system's atom count + atom_idx = _get_atom_indices_per_system(state.system_idx, n_systems) # [N] + + model_output = model(state) + energy = model_output["energy"] # [S] + forces = model_output["forces"] # [N, 3] + stress = model_output.get("stress") # [S, 3, 3] or None + + alpha_t = torch.full((n_systems,), alpha, **tensor_args) # [S] + max_step_t = torch.full((n_systems,), max_step, **tensor_args) # [S] + n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) # [S] + + if cell_filter is not None: + # Extended Hessian: (3*global_max_atoms + 9) x (3*global_max_atoms + 9) + # The extra 9 DOFs are for cell parameters (3x3 matrix flattened) + dim = 3 * global_max_atoms + (3 * 3) # D_ext + hessian = ( + torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha + ) # [S, D_ext, D_ext] + + cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) + + # Note (AG): At initialization, deform_grad is identity, so we have: + # fractional = Cartesian / cell and scaled forces = forces @ I = forces + # For ASE compatibility, we need to store prev_positions as fractional coords + # and prev_forces as scaled forces + + # Get initial deform_grad (identity at start since reference_cell = current_cell) + reference_cell = state.cell.clone() # [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + reference_cell.mT, state.cell.mT + ) # [S, 3, 3] + + # Initial fractional positions = solve(deform_grad, positions) = positions + # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Initial scaled forces = forces @ deform_grad = forces + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] + scaled_forces = torch.bmm( + forces.unsqueeze(1), # [N, 1, 3] + cur_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) + + common_args = { + "positions": state.positions.clone(), # [N, 3] + "masses": state.masses.clone(), # [N] + "cell": state.cell.clone(), # [S, 3, 3] + "atomic_numbers": state.atomic_numbers.clone(), # [N] + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "hessian": hessian, # [S, D_ext, D_ext] + # Note (AG): Store fractional positions and scaled forces + # for ASE compatibility + "prev_forces": scaled_forces, # [N, 3] (scaled) + "prev_positions": frac_positions, # [N, 3] (fractional) + "alpha": alpha_t, # [S] + 
"max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # scalar M + "system_idx": state.system_idx.clone(), # [N] + "pbc": state.pbc, # [S, 3] + "reference_cell": reference_cell, # [S, 3, 3] + "cell_filter": cell_filter_funcs, + "charge": state.charge, # preserve charge + "spin": state.spin, # preserve spin + "_constraints": state.constraints, # preserve constraints + } + + cell_state = CellBFGSState(**common_args) + + # Initialize cell-specific attributes (cell_positions, cell_forces, etc.) + # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] + init_fn(cell_state, model, **filter_kwargs) + + # Store prev_cell_positions and prev_cell_forces for Hessian update + cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] + cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + + return cell_state + + # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms + dim = 3 * global_max_atoms # D + hessian = ( + torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha + ) # [S, D, D] + + common_args = { + "positions": state.positions.clone(), # [N, 3] + "masses": state.masses.clone(), # [N] + "cell": state.cell.clone(), # [S, 3, 3] + "atomic_numbers": state.atomic_numbers.clone(), # [N] + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "hessian": hessian, # [S, D, D] + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "alpha": alpha_t, # [S] + "max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # scalar M + "system_idx": state.system_idx.clone(), # [N] + "pbc": state.pbc, # [S, 3] + "charge": state.charge, # preserve charge + "spin": state.spin, # preserve spin + "_constraints": state.constraints, # preserve constraints + } + + return BFGSState(**common_args) + + +def bfgs_step( # noqa: C901, PLR0915 + state: "BFGSState | CellBFGSState", + model: "ModelInterface", +) -> "BFGSState | CellBFGSState": + """Perform one BFGS optimization step. + + Updates the Hessian estimate and moves atoms. If state is a CellBFGSState, + forces are transformed using the deformation gradient to work in the same + scaled coordinate space as ASE's cell filters (matching FIRE's approach). + + For cell optimization, prev_positions are stored as fractional coordinates + and prev_forces as scaled forces, exactly matching ASE's pos0/forces0. + + Shape notation: + N = total atoms across all systems (n_atoms) + S = number of systems (n_systems) + M = max atoms per system (max_atoms) + D = 3*M (position DOFs) + D_ext = 3*M + 9 (extended DOFs with cell) + + Args: + state: Current optimization state + model: Calculator model + + Returns: + Updated state + """ + from torch_sim.optimizers import CellBFGSState + + # Note (AG): eps kept same as ASE's BFGS. 
+ eps = 1e-7 + is_cell_state = isinstance(state, CellBFGSState) + + # Derive global_max_atoms from hessian shape + hessian_dim = state.hessian.shape[1] + global_max_atoms = (hessian_dim - 9) // 3 if is_cell_state else hessian_dim // 3 + + if is_cell_state: + # Get current deformation gradient + # reference_cell.mT: [S, 3, 3], row_vector_cell: [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + + # Transform forces to scaled coordinates + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] + forces_scaled = torch.bmm( + state.forces.unsqueeze(1), # [N, 1, 3] + cur_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) # [N, 3] + + # Current fractional positions + # positions: [N, 3] -> frac_positions: [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Pack into dense tensors [N, 3] -> [S, M, 3] -> [S, D] + # For cell state, prev_positions is already fractional (stored that way) + # prev_forces is already scaled + # Note (AG): Optimization potential here. + forces_new = _pad_to_dense( + forces_scaled, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + forces_old = _pad_to_dense( + state.prev_forces, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_new = _pad_to_dense( + frac_positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_old = _pad_to_dense( + state.prev_positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + # Extend with cell DOFs: [S, 3, 3] -> [S, 9] + cell_pos_new = state.cell_positions.reshape(state.n_systems, 9) # [S, 9] + cell_forces_new = state.cell_forces.reshape(state.n_systems, 9) # [S, 9] + cell_pos_old = state.prev_cell_positions.reshape(state.n_systems, 9) # [S, 9] + cell_forces_old = state.prev_cell_forces.reshape(state.n_systems, 9) # [S, 9] + + # Concatenate: extended = [positions, cell_positions] + # [S, D] + [S, 9] -> [S, D_ext] + pos_new = torch.cat([pos_new, cell_pos_new], dim=1) # [S, D_ext] + forces_new = torch.cat([forces_new, cell_forces_new], dim=1) # [S, D_ext] + pos_old = torch.cat([pos_old, cell_pos_old], dim=1) # [S, D_ext] + forces_old = torch.cat([forces_old, cell_forces_old], dim=1) # [S, D_ext] + else: + forces_scaled = state.forces # [N, 3] + + # Pack into dense tensors [N, 3] -> [S, M, 3] -> [S, D] + forces_new = _pad_to_dense( + state.forces, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + forces_old = _pad_to_dense( + state.prev_forces, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_new = _pad_to_dense( + state.positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_old = _pad_to_dense( + state.prev_positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + # Calculate displacements and 
force changes + # dim = D or D_ext depending on cell_state + dpos = pos_new - pos_old # [S, dim] + dforces = forces_new - forces_old # [S, dim] + + # Identify systems with significant movement + max_disp = torch.max(torch.abs(dpos), dim=1).values # [S] + update_mask = max_disp >= eps # [S] bool + + # Update Hessian for active systems (BFGS update formula) + if update_mask.any(): + idx = update_mask + H = state.hessian[idx] # [S_active, dim, dim] + + dp = dpos[idx].unsqueeze(2) # [S_active, dim, 1] + df = dforces[idx].unsqueeze(2) # [S_active, dim, 1] # noqa: PD901 + + # a = dp^T @ df: [S_active, 1, dim] @ [S_active, dim, 1] -> [S_active, 1, 1] + a = torch.bmm(dp.transpose(1, 2), df).squeeze(2) # [S_active, 1] + # dg = H @ dp: [S_active, dim, dim] @ [S_active, dim, 1] -> [S_active, dim, 1] + dg = torch.bmm(H, dp) # [S_active, dim, 1] + # b = dp^T @ dg: [S_active, 1, dim] @ [S_active, dim, 1] -> [S_active, 1, 1] + b = torch.bmm(dp.transpose(1, 2), dg).squeeze(2) # [S_active, 1] + + # term1 = df @ df^T / a: [S_active, dim, dim] + term1 = torch.bmm(df, df.transpose(1, 2)) / (a.unsqueeze(2) + 1e-30) + # term2 = dg @ dg^T / b: [S_active, dim, dim] + term2 = torch.bmm(dg, dg.transpose(1, 2)) / (b.unsqueeze(2) + 1e-30) + + state.hessian[idx] = H - term1 - term2 # [S_active, dim, dim] + + # Calculate step direction using eigendecomposition + # step = V @ (|omega|^-1) @ V^T @ forces (pseudo-inverse via eigendecomposition) + # Note (AG): We use eigendecomposition rather than directly inverting H so we can + # take the absolute value of eigenvalues (|omega|). This ensures the step is always + # in a descent direction even if the Hessian approximation has negative eigenvalues. + + # Size-binned eigendecomposition: group systems by actual Hessian size + hessian_dim = state.hessian.shape[1] + step_dense = torch.zeros( + state.n_systems, hessian_dim, device=state.device, dtype=state.dtype + ) # [S, dim] + + # Get unique sizes and process each group with batched eigendecomp + # TODO(AG): If we sort and get the sizes before hand we can reduce the + # python loop overhead. 
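+    # e.g. a batch of systems with 8, 8 and 27 atoms needs two eigh calls: one
+    # batched over the two 24x24 Hessians (33x33 with cell DOFs) and one over
+    # the single 81x81 (90x90) Hessian.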
+ unique_sizes = state.max_atoms.unique() + + for size in unique_sizes: + actual_dim = int(3 * size.item()) + (9 if is_cell_state else 0) + mask = state.max_atoms == size # [S] bool - systems with this size + + # Extract actual-sized Hessians and forces for this group + H_group = state.hessian[mask, :actual_dim, :actual_dim] # [G, d, d] + f_group = forces_new[mask, :actual_dim] # [G, d] + + # Batched eigendecomposition on actual size (no padding overhead) + omega, V = torch.linalg.eigh(H_group) # omega: [G, d], V: [G, d, d] + abs_omega = torch.abs(omega).clamp(min=1e-30) # [G, d] + + # Compute step: V @ (V^T @ f / |omega|) + vt_f = torch.bmm(V.transpose(1, 2), f_group.unsqueeze(2)) # [G, d, 1] + step_group = torch.bmm(V, vt_f / abs_omega.unsqueeze(2)).squeeze(2) # [G, d] + + # Place results back into step_dense (padded to hessian_dim) + indices = mask.nonzero(as_tuple=True)[0] + step_dense[indices, :actual_dim] = step_group + + # Split step into position and cell components + atom_dim = 3 * global_max_atoms # D + if is_cell_state: + step_pos = step_dense[:, :atom_dim] # [S, D] + step_cell = step_dense[:, atom_dim:] # [S, 9] + else: + step_pos = step_dense # [S, D] + + # Scale step if it exceeds max_step + step_atoms = step_pos.view(state.n_systems, global_max_atoms, 3) # [S, M, 3] + atom_norms = torch.norm(step_atoms, dim=2) # [S, M] + + if is_cell_state: + step_cell_reshaped = step_cell.view(state.n_systems, 3, 3) # [S, 3, 3] + cell_norms = torch.norm(step_cell_reshaped, dim=2) # [S, 3] + all_norms = torch.cat([atom_norms, cell_norms], dim=1) # [S, M+3] + max_disp_per_sys = torch.max(all_norms, dim=1).values # [S] + else: + max_disp_per_sys = torch.max(atom_norms, dim=1).values # [S] + + scale = torch.ones_like(max_disp_per_sys) # [S] + needs_scale = max_disp_per_sys > state.max_step # [S] bool + scale[needs_scale] = state.max_step[needs_scale] / ( + max_disp_per_sys[needs_scale] + 1e-30 + ) + + step_pos = step_pos * scale.unsqueeze(1) # [S, D] + if is_cell_state: + step_cell = step_cell * scale.unsqueeze(1) # [S, 9] + + # Unpack dense step to flat: [S, M, 3] -> [N, 3] + flat_step = step_pos.view(state.n_systems, global_max_atoms, 3)[ + state.system_idx, state.atom_idx_in_system + ] # [N, 3] + + # Save previous state for next Hessian update + # For cell state: store fractional positions and scaled forces (ASE convention) + if is_cell_state: + state.prev_positions = frac_positions.clone() # [N, 3] (fractional) + state.prev_forces = forces_scaled.clone() # [N, 3] (scaled) + state.prev_cell_positions = state.cell_positions.clone() # [S, 3, 3] + state.prev_cell_forces = state.cell_forces.clone() # [S, 3, 3] + + # Apply cell step: [S, 9] -> [S, 3, 3] + dr_cell = step_cell.view(state.n_systems, 3, 3) # [S, 3, 3] + cell_positions_new = state.cell_positions + dr_cell # [S, 3, 3] + state.cell_positions = cell_positions_new # [S, 3, 3] + + # Determine if Frechet filter + init_fn, _step_fn = state.cell_filter + is_frechet = init_fn is frechet_cell_filter_init + + if is_frechet: + # Frechet: deform_grad = exp(cell_positions / cell_factor) + cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1) + deform_grad_log_new = cell_positions_new / cell_factor_reshaped # [S, 3, 3] + deform_grad_new = torch.matrix_exp(deform_grad_log_new) # [S, 3, 3] + else: + # UnitCell: deform_grad = cell_positions / cell_factor + cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) + deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3] + + # Update cell: new_cell = 
reference_cell @ deform_grad^T
+        # reference_cell.mT: [S, 3, 3], deform_grad_new: [S, 3, 3]
+        state.row_vector_cell = torch.bmm(
+            state.reference_cell.mT, deform_grad_new.transpose(-2, -1)
+        )  # [S, 3, 3]
+
+        # Apply position step in fractional space, then convert to Cartesian
+        new_frac = frac_positions + flat_step  # [N, 3]
+
+        new_deform_grad = cell_filters.deform_grad(
+            state.reference_cell.mT, state.row_vector_cell
+        )  # [S, 3, 3]
+        # new_positions = new_frac @ deform_grad^T
+        new_positions = torch.bmm(
+            new_frac.unsqueeze(1),  # [N, 1, 3]
+            new_deform_grad[state.system_idx].transpose(-2, -1),  # [N, 3, 3]
+        ).squeeze(1)  # [N, 3]
+        state.set_constrained_positions(new_positions)  # [N, 3]
+    else:
+        state.prev_positions = state.positions.clone()  # [N, 3]
+        state.prev_forces = state.forces.clone()  # [N, 3]
+        state.set_constrained_positions(state.positions + flat_step)  # [N, 3]
+
+    # Evaluate new forces and energy
+    model_output = model(state)
+    state.set_constrained_forces(model_output["forces"])  # [N, 3]
+    state.energy = model_output["energy"]  # [S]
+    if "stress" in model_output:
+        state.stress = model_output["stress"]  # [S, 3, 3]
+
+    # Update cell forces for the next step (cell state only): [S, 3, 3]
+    if is_cell_state:
+        cell_filters.compute_cell_forces(model_output, state)
+
+    state.n_iter += 1
+
+    return state
diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py
index 3ff0cf2de..3a061f23b 100644
--- a/torch_sim/optimizers/cell_filters.py
+++ b/torch_sim/optimizers/cell_filters.py
@@ -15,7 +15,7 @@
 import torch_sim.math as fm
 from torch_sim.models.interface import ModelInterface
-from torch_sim.optimizers.state import FireState, OptimState
+from torch_sim.optimizers.state import BFGSState, FireState, LBFGSState, OptimState
 from torch_sim.state import SimState
@@ -309,7 +309,13 @@ def compute_cell_forces[T: AnyCellState](
         state.cell_forces = cell_forces / state.cell_factor
     else:
         # Unit cell force computation
-        state.cell_forces = virial / state.cell_factor
+        # Note (AG): ASE transforms virial as:
+        #   virial = np.linalg.solve(cur_deform_grad, virial.T).T
+        cur_deform_grad = deform_grad(state.reference_cell.mT, state.row_vector_cell)
+        virial_transformed = torch.linalg.solve(
+            cur_deform_grad, virial.transpose(-2, -1)
+        ).transpose(-2, -1)
+        state.cell_forces = virial_transformed / state.cell_factor


 CellFilterFuncs = tuple[Callable[..., None], Callable[..., None]]  # (init_fn, update_fn)
@@ -378,4 +384,60 @@ class CellFireState(CellOptimState, FireState):
     )


-AnyCellState = CellFireState | CellOptimState
+@dataclass(kw_only=True)
+class CellBFGSState(CellOptimState, BFGSState):
+    """State class for BFGS optimization with cell optimization.
+
+    Combines BFGS position optimization with cell filter for simultaneous
+    optimization of atomic positions and unit cell parameters using a unified
+    extended coordinate space (positions + cell DOFs).
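+
+    Per system, the extended coordinate vector has length 3 * max_atoms + 9:
+    the padded atomic DOFs come first, followed by the 9 cell DOFs, matching
+    the Hessian layout [n_systems, 3 * max_atoms + 9, 3 * max_atoms + 9].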
+ """ + + # Previous cell state for Hessian update + prev_cell_positions: torch.Tensor = field(default_factory=lambda: None) + prev_cell_forces: torch.Tensor = field(default_factory=lambda: None) + + _atom_attributes = ( + CellOptimState._atom_attributes # noqa: SLF001 + | BFGSState._atom_attributes # noqa: SLF001 + ) + _system_attributes = ( + CellOptimState._system_attributes # noqa: SLF001 + | BFGSState._system_attributes # noqa: SLF001 + | {"prev_cell_positions", "prev_cell_forces"} + ) + _global_attributes = ( + CellOptimState._global_attributes # noqa: SLF001 + | BFGSState._global_attributes # noqa: SLF001 + ) + + +@dataclass(kw_only=True) +class CellLBFGSState(CellOptimState, LBFGSState): + """State class for L-BFGS optimization with cell optimization. + + Combines L-BFGS position optimization with cell filter for simultaneous + optimization of atomic positions and unit cell parameters using a unified + extended coordinate space (positions + cell DOFs). + """ + + # Previous cell state for history update + prev_cell_positions: torch.Tensor = field(default_factory=lambda: None) + prev_cell_forces: torch.Tensor = field(default_factory=lambda: None) + + _atom_attributes = ( + CellOptimState._atom_attributes # noqa: SLF001 + | LBFGSState._atom_attributes # noqa: SLF001 + ) + _system_attributes = ( + CellOptimState._system_attributes # noqa: SLF001 + | LBFGSState._system_attributes # noqa: SLF001 + | {"prev_cell_positions", "prev_cell_forces"} + ) + _global_attributes = ( + CellOptimState._global_attributes # noqa: SLF001 + | LBFGSState._global_attributes # noqa: SLF001 + ) + + +AnyCellState = CellFireState | CellOptimState | CellBFGSState | CellLBFGSState diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py new file mode 100644 index 000000000..f4dd79a34 --- /dev/null +++ b/torch_sim/optimizers/lbfgs.py @@ -0,0 +1,622 @@ +"""L-BFGS (Limited-memory BFGS) optimizer implementation. + +This module provides a batched L-BFGS optimizer for atomic structure relaxation. +L-BFGS is a quasi-Newton method that approximates the inverse Hessian using +a limited history of position and gradient differences, making it memory-efficient +for large systems while achieving superlinear convergence near the minimum. + +When cell_filter is active, forces are transformed using the deformation gradient +to work in the same scaled coordinate space as ASE's UnitCellFilter/FrechetCellFilter. +The prev_forces and prev_positions are stored in the scaled/fractional space to match +ASE's behavior exactly. +""" + +from typing import TYPE_CHECKING, Any + +import torch + +import torch_sim as ts +from torch_sim.optimizers import cell_filters +from torch_sim.optimizers.cell_filters import frechet_cell_filter_init +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import CellLBFGSState, LBFGSState + from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs + + +def _atoms_to_padded( + x: torch.Tensor, + system_idx: torch.Tensor, + n_systems: int, + max_atoms: int, +) -> torch.Tensor: + """Convert atom-indexed [N, 3] to padded per-system [S, M, 3]. 
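+
+    Padding slots stay zero-filled; per-system reductions (see _per_system_vdot)
+    mask them out so padded entries never contribute to dot products.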
+ + Args: + x: Tensor of shape [N, 3] where N = total atoms + system_idx: System index for each atom [N] + n_systems: Number of systems S + max_atoms: Maximum atoms per system M + + Returns: + Tensor of shape [S, M, 3] with zeros for padding + """ + device, dtype = x.device, x.dtype + out = torch.zeros((n_systems, max_atoms, 3), device=device, dtype=dtype) + # Create atom index within each system + atom_idx = torch.zeros_like(system_idx) + for sys in range(n_systems): + mask = system_idx == sys + atom_idx[mask] = torch.arange(mask.sum(), device=device) + out[system_idx, atom_idx] = x + return out + + +def _padded_to_atoms( + x: torch.Tensor, + system_idx: torch.Tensor, + n_atoms: int, +) -> torch.Tensor: + """Convert padded per-system [S, M, 3] to atom-indexed [N, 3]. + + Args: + x: Tensor of shape [S, M, 3] + system_idx: System index for each atom [N] + n_atoms: Total number of atoms N + + Returns: + Tensor of shape [N, 3] + """ + n_systems = x.shape[0] + device = x.device + # Create atom index within each system + atom_idx = torch.zeros(n_atoms, device=device, dtype=torch.long) + for sys in range(n_systems): + mask = system_idx == sys + atom_idx[mask] = torch.arange(mask.sum(), device=device) + return x[system_idx, atom_idx] # [N, 3] + + +def _per_system_vdot( + a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor +) -> torch.Tensor: + """Compute per-system dot product with padding mask. + + Args: + a: Tensor of shape [S, M, 3] + b: Tensor of shape [S, M, 3] + mask: Boolean mask [S, M] where True = valid atom + + Returns: + Tensor of shape [S] with per-system dot products + """ + # Element-wise product then sum over atoms and coordinates + prod = (a * b).sum(dim=-1) # [S, M] + prod = prod * mask.float() # Zero out padded atoms + return prod.sum(dim=-1) # [S] + + +def lbfgs_init( + state: SimState | StateDict, + model: "ModelInterface", + *, + step_size: float = 0.1, + alpha: float | None = None, + cell_filter: "CellFilter | CellFilterFuncs | None" = None, + **filter_kwargs: Any, +) -> "LBFGSState | CellLBFGSState": + r"""Create an initial LBFGSState from a SimState or state dict. + + Initializes forces/energy, clears the (s, y) memory, and broadcasts the + fixed step size to all systems. + + Shape notation: + N = total atoms across all systems (n_atoms) + S = number of systems (n_systems) + M = max atoms per system (global_max_atoms) + H = history length (starts at 0) + M_ext = M + 3 (extended with cell DOFs per system) + + Args: + state: Input state as SimState object or state parameter dict + model: Model that computes energies, forces, and optionally stress + step_size: Fixed per-system step length (damping factor). + If using ASE mode (fixed alpha), set this to 1.0 (or your damping). + If using dynamic mode (default), 0.1 is a safe starting point. + alpha: Initial inverse Hessian stiffness guess (ASE parameter). + If provided (e.g. 70.0), fixes H0 = 1/alpha for all steps (ASE-style). + If None (default), H0 is updated dynamically (Standard L-BFGS). + cell_filter: Filter for cell optimization (None for position-only optimization) + **filter_kwargs: Additional arguments passed to cell filter initialization + + Returns: + LBFGSState with initialized optimization tensors, or CellLBFGSState if + cell_filter is provided + + Notes: + The optimizer supports two modes of operation: + 1. **Standard L-BFGS (default)**: Set `alpha=None`. The inverse Hessian + diagonal $H_0$ is updated dynamically at each step using the scaling + $\gamma_k = (s^T y) / (y^T y)$. 
This is the standard behavior described + by Nocedal & Wright. + 2. **ASE Compatibility Mode**: Set `alpha` (e.g. 70.0) and `step_size=1.0`. + The inverse Hessian diagonal is fixed at $H_0 = 1/\alpha$ throughout the + optimization, and the step is scaled by `step_size` (damping). + This matches `ase.optimize.LBFGS(alpha=70.0, damping=1.0)`. + """ + from torch_sim.optimizers import CellLBFGSState, LBFGSState + + tensor_args = {"device": model.device, "dtype": model.dtype} + + if not isinstance(state, SimState): + state = SimState(**state) + + n_systems = state.n_systems # S + + # Compute max atoms per system for per-system history storage + counts = state.n_atoms_per_system # [S] + global_max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 # M + max_atoms = counts.clone() # [S] - each system's atom count + + # Get initial forces and energy from model + model_output = model(state) + energy = model_output["energy"] # [S] + forces = model_output["forces"] # [N, 3] + stress = model_output.get("stress") # [S, 3, 3] or None + + # Initialize empty per-system history tensors + # History shape: [S, H, M, 3] where H=0 at start, M = global_max_atoms + s_history = torch.zeros( + (n_systems, 0, global_max_atoms, 3), **tensor_args + ) # [S, 0, M, 3] + y_history = torch.zeros( + (n_systems, 0, global_max_atoms, 3), **tensor_args + ) # [S, 0, M, 3] + + # Alpha tensor: 0.0 means dynamic, >0 means fixed + alpha_val = 0.0 if alpha is None else alpha + alpha_tensor = torch.full((n_systems,), alpha_val, **tensor_args) # [S] + + common_args = { + # Copy SimState attributes + "positions": state.positions.clone(), # [N, 3] + "masses": state.masses.clone(), # [N] + "cell": state.cell.clone(), # [S, 3, 3] + "atomic_numbers": state.atomic_numbers.clone(), # [N] + "system_idx": state.system_idx.clone(), # [N] + "pbc": state.pbc, # [S, 3] + "charge": state.charge, # preserve charge + "spin": state.spin, # preserve spin + "_constraints": state.constraints, # preserve constraints + # Optimization state + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + # L-BFGS specific state + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "s_history": s_history, # [S, 0, M, 3] + "y_history": y_history, # [S, 0, M, 3] + "step_size": torch.full((n_systems,), step_size, **tensor_args), # [S] + "alpha": alpha_tensor, # [S] + "n_iter": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), + "max_atoms": max_atoms, # [S] atoms per system for padding + } + + if cell_filter is not None: + cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) + + # At initialization, deform_grad is identity since reference_cell = current_cell + # Store prev_positions as fractional (same as Cartesian for identity deform_grad) + # Store prev_forces as scaled (same as Cartesian for identity deform_grad) + reference_cell = state.cell.clone() # [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + reference_cell.mT, state.cell.mT + ) # [S, 3, 3] + + # Initial fractional positions = positions + # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] -> [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Initial scaled forces = forces @ deform_grad = forces + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] + scaled_forces = torch.bmm( + forces.unsqueeze(1), # [N, 1, 3] + 
cur_deform_grad[state.system_idx],  # [N, 3, 3]
+        ).squeeze(1)  # [N, 3]
+
+        common_args["reference_cell"] = reference_cell  # [S, 3, 3]
+        common_args["cell_filter"] = cell_filter_funcs
+        # Store fractional positions and scaled forces for ASE compatibility
+        common_args["prev_positions"] = frac_positions  # [N, 3]
+        common_args["prev_forces"] = scaled_forces  # [N, 3]
+
+        # Extended per-system history includes cell DOFs (3 "virtual atoms" per system)
+        # History shape: [S, H, M+3, 3] where M = global_max_atoms
+        extended_size_per_system = global_max_atoms + 3  # M_ext = M + 3
+        common_args["s_history"] = torch.zeros(
+            (n_systems, 0, extended_size_per_system, 3), **tensor_args
+        )  # [S, 0, M_ext, 3]
+        common_args["y_history"] = torch.zeros(
+            (n_systems, 0, extended_size_per_system, 3), **tensor_args
+        )  # [S, 0, M_ext, 3]
+
+        cell_state = CellLBFGSState(**common_args)
+
+        # Initialize cell-specific attributes
+        # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S]
+        init_fn(cell_state, model, **filter_kwargs)
+
+        # Store prev_cell_positions and prev_cell_forces for history update
+        cell_state.prev_cell_positions = cell_state.cell_positions.clone()  # [S, 3, 3]
+        cell_state.prev_cell_forces = cell_state.cell_forces.clone()  # [S, 3, 3]
+
+        return cell_state
+
+    return LBFGSState(**common_args)
+
+
+def lbfgs_step(  # noqa: PLR0915, C901
+    state: "LBFGSState | CellLBFGSState",
+    model: "ModelInterface",
+    *,
+    max_history: int = 20,
+    max_step: float = 0.2,
+    curvature_eps: float = 1e-12,
+) -> "LBFGSState | CellLBFGSState":
+    r"""Advance one L-BFGS iteration using the two-loop recursion.
+
+    Computes the search direction via the two-loop recursion, applies a
+    fixed-size step with per-system displacement capping, evaluates new
+    forces and energy, and appends the new (s, y) pair to the limited-memory
+    history; pairs with degenerate curvature are neutralized inside the
+    recursion via the rho guard.
+
+    When cell_filter is active, forces are transformed using the deformation
+    gradient to work in the same scaled coordinate space as ASE's cell filters.
+    The prev_positions are stored as fractional coordinates and prev_forces as
+    scaled forces, exactly matching ASE's pos0/forces0.
+
+    Shape notation:
+        N = total atoms across all systems (n_atoms)
+        S = number of systems (n_systems)
+        M = max atoms per system (history dimension)
+        H = current history length
+        M_ext = M + 3 (extended with cell DOFs per system)
+
+    Args:
+        state: Current L-BFGS optimization state
+        model: Model that computes energies, forces, and optionally stress
+        max_history: Number of (s, y) pairs retained for the two-loop recursion.
+        max_step: Caps the maximum per-atom (and per cell DOF) displacement
+            each iteration.
+        curvature_eps: Threshold on the curvature ⟨y, s⟩ below which a history
+            pair is skipped (its rho is zeroed) in the two-loop recursion.
+
+    Returns:
+        Updated LBFGSState after one optimization step
+
+    Notes:
+        - If `state.alpha > 0` (ASE mode), the initial inverse Hessian estimate is
+          fixed at $H_0 = 1/\alpha$.
+        - Otherwise (Standard mode), $H_0$ varies at each step based on the
+          curvature of the most recent history pair.
+
+    References:
+        - Nocedal & Wright, Numerical Optimization (L-BFGS two-loop recursion).
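+
+    The recursion implemented below, per system with h stored pairs,
+    gradient $g = -f$ and $\rho_i = 1 / (y_i^\top s_i)$:
+
+        q = g
+        for i = h-1 .. 0:  a_i = rho_i * (s_i^T q);  q -= a_i * y_i
+        z = gamma * q                                # gamma sets H_0
+        for i = 0 .. h-1:  b = rho_i * (y_i^T z);    z += (a_i - b) * s_i
+        step = step_size * (-z)                      # then capped at max_step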
+ """ + from torch_sim.optimizers import CellLBFGSState + + is_cell_state = isinstance(state, CellLBFGSState) + device, dtype = model.device, model.dtype + eps = 1e-8 if dtype == torch.float32 else 1e-16 + n_systems = state.n_systems # S + n_atoms = state.n_atoms # N + + # Derive max_atoms from history shape: [S, H, M, 3] or [S, H, M_ext, 3] + history_dim = state.s_history.shape[2] # M or M_ext + if is_cell_state: + max_atoms_ext = history_dim # M_ext = M + 3 + max_atoms = max_atoms_ext - 3 # M + else: + max_atoms = history_dim # M + max_atoms_ext = max_atoms + + # Create atom index within each system for padding/unpadding + atom_idx_in_sys = torch.zeros(n_atoms, device=device, dtype=torch.long) + for sys in range(n_systems): + mask = state.system_idx == sys + atom_idx_in_sys[mask] = torch.arange(mask.sum(), device=device) + + # Create valid atom mask for per-system operations: [S, M] + atom_mask = torch.zeros((n_systems, max_atoms), device=device, dtype=torch.bool) + for sys in range(n_systems): + n_atoms_sys = int(state.max_atoms[sys].item()) + atom_mask[sys, :n_atoms_sys] = True + + # Extended mask including cell DOFs: [S, M_ext] + if is_cell_state: + ext_mask = torch.cat( + [ + atom_mask, + torch.ones((n_systems, 3), device=device, dtype=torch.bool), + ], + dim=1, + ) # [S, M_ext] + else: + ext_mask = atom_mask # [S, M] + + if is_cell_state: + # Get current deformation gradient + # reference_cell.mT: [S, 3, 3], row_vector_cell: [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + + # Transform forces to scaled coordinates + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] + forces_scaled = torch.bmm( + state.forces.unsqueeze(1), # [N, 1, 3] + cur_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) # [N, 3] + + # Current fractional positions + # positions: [N, 3] -> frac_positions: [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Convert to padded per-system format: [S, M, 3] + g_atoms = _atoms_to_padded(-forces_scaled, state.system_idx, n_systems, max_atoms) + # Cell forces: [S, 3, 3] -> [S, 3, 3] + g_cell = -state.cell_forces # [S, 3, 3] + # Extended gradient: [S, M_ext, 3] = [S, M+3, 3] + g = torch.cat([g_atoms, g_cell], dim=1) # [S, M_ext, 3] + else: + # Convert to padded per-system format: [S, M, 3] + g = _atoms_to_padded(-state.forces, state.system_idx, n_systems, max_atoms) + + # Two-loop recursion to compute search direction d = -H_k g_k + # History shape: [S, H, M_ext, 3] or [S, H, M, 3] + cur_history_len = state.s_history.shape[1] # H + q = g.clone() # [S, M_ext, 3] or [S, M, 3] + alphas: list[torch.Tensor] = [] # list of [S] tensors + + # First loop (from newest to oldest) + for i in range(cur_history_len - 1, -1, -1): + s_i = state.s_history[:, i] # [S, M_ext, 3] or [S, M, 3] + y_i = state.y_history[:, i] # [S, M_ext, 3] or [S, M, 3] + + # ys = y^T s per system: [S] + ys = _per_system_vdot(y_i, s_i, ext_mask) # [S] + rho = torch.where( + ys.abs() > curvature_eps, + 1.0 / (ys + eps), + torch.zeros_like(ys), + ) # [S] + sq = _per_system_vdot(s_i, q, ext_mask) # [S] + alpha = rho * sq # [S] + alphas.append(alpha) + + # q <- q - alpha * y_i (broadcast alpha to [S, 1, 1]) + q = q - alpha.view(-1, 1, 1) * y_i # [S, M_ext, 3] + + # Initial H0 scaling: gamma = (s^T y)/(y^T y) using the last pair + if cur_history_len > 0: + s_last = state.s_history[:, -1] # [S, M_ext, 3] 
+ y_last = state.y_history[:, -1] # [S, M_ext, 3] + sy = _per_system_vdot(s_last, y_last, ext_mask) # [S] + yy = _per_system_vdot(y_last, y_last, ext_mask) # [S] + gamma_dynamic = torch.where( + yy.abs() > curvature_eps, + sy / (yy + eps), + torch.ones_like(yy), + ) # [S] + else: + gamma_dynamic = torch.ones((n_systems,), device=device, dtype=dtype) # [S] + + # Fixed gamma (ASE style: 1/alpha) + # If state.alpha > 0, use that. Else use dynamic. + is_fixed = state.alpha > 1e-6 # [S] bool + gamma_fixed = 1.0 / (state.alpha + eps) # [S] + gamma = torch.where(is_fixed, gamma_fixed, gamma_dynamic) # [S] + + # z = gamma * q (broadcast gamma to [S, 1, 1]) + z = gamma.view(-1, 1, 1) * q # [S, M_ext, 3] + + # Second loop (from oldest to newest) + for i in range(cur_history_len): + s_i = state.s_history[:, i] # [S, M_ext, 3] + y_i = state.y_history[:, i] # [S, M_ext, 3] + + ys = _per_system_vdot(y_i, s_i, ext_mask) # [S] + rho = torch.where( + ys.abs() > curvature_eps, + 1.0 / (ys + eps), + torch.zeros_like(ys), + ) # [S] + yz = _per_system_vdot(y_i, z, ext_mask) # [S] + beta = rho * yz # [S] + + alpha_i = alphas[cur_history_len - 1 - i] # [S] + # z <- z + s_i * (alpha - beta) + coeff = (alpha_i - beta).view(-1, 1, 1) # [S, 1, 1] + z = z + s_i * coeff # [S, M_ext, 3] + + d = -z # search direction: [S, M_ext, 3] + + # Apply step_size scaling per system: [S, 1, 1] + step = state.step_size.view(-1, 1, 1) * d # [S, M_ext, 3] + + # Per-system max norm (only over valid atoms/DOFs) + step_norms = torch.linalg.norm(step, dim=-1) # [S, M_ext] + step_norms = step_norms * ext_mask.float() # Zero out padded + sys_max = step_norms.max(dim=1).values # [S] + + # Scaling factors per system: <= 1.0 + scale = torch.where( + sys_max > max_step, + max_step / (sys_max + eps), + torch.ones_like(sys_max), + ) # [S] + step = scale.view(-1, 1, 1) * step # [S, M_ext, 3] + + # Split step into position and cell components + if is_cell_state: + step_padded = step[:, :max_atoms] # [S, M, 3] + step_cell = step[:, max_atoms:] # [S, 3, 3] + # Convert padded step to atom-level + step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms) + else: + step_padded = step # [S, M, 3] + step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms) + + # Save previous state for history update + # For cell state: store fractional positions and scaled forces (ASE convention) + if is_cell_state: + state.prev_positions = frac_positions.clone() # [N, 3] (fractional) + state.prev_forces = forces_scaled.clone() # [N, 3] (scaled) + state.prev_cell_positions = state.cell_positions.clone() # [S, 3, 3] + state.prev_cell_forces = state.cell_forces.clone() # [S, 3, 3] + + # Apply cell step + dr_cell = step_cell # [S, 3, 3] + cell_positions_new = state.cell_positions + dr_cell # [S, 3, 3] + state.cell_positions = cell_positions_new # [S, 3, 3] + + # Determine if Frechet filter + init_fn, _step_fn = state.cell_filter + is_frechet = init_fn is frechet_cell_filter_init + + if is_frechet: + # Frechet: deform_grad = exp(cell_positions / cell_factor) + cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1) + deform_grad_log_new = cell_positions_new / cell_factor_reshaped # [S, 3, 3] + deform_grad_new = torch.matrix_exp(deform_grad_log_new) # [S, 3, 3] + else: + # UnitCell: deform_grad = cell_positions / cell_factor + cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) + deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3] + + # Update cell: new_cell = reference_cell @ deform_grad^T + # 
reference_cell.mT: [S, 3, 3], deform_grad_new: [S, 3, 3] + state.row_vector_cell = torch.bmm( + state.reference_cell.mT, deform_grad_new.transpose(-2, -1) + ) # [S, 3, 3] + + # Apply position step in fractional space, then convert to Cartesian + new_frac = frac_positions + step_positions # [N, 3] + + new_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + # new_positions = new_frac @ deform_grad^T + new_positions = torch.bmm( + new_frac.unsqueeze(1), # [N, 1, 3] + new_deform_grad[state.system_idx].transpose(-2, -1), # [N, 3, 3] + ).squeeze(1) # [N, 3] + state.set_constrained_positions(new_positions) # [N, 3] + else: + state.prev_positions = state.positions.clone() # [N, 3] + state.prev_forces = state.forces.clone() # [N, 3] + state.set_constrained_positions(state.positions + step_positions) # [N, 3] + + # Evaluate new forces/energy + model_output = model(state) + new_forces = model_output["forces"] # [N, 3] + new_energy = model_output["energy"] # [S] + new_stress = model_output.get("stress") # [S, 3, 3] or None + + # Update cell forces for next step: [S, 3, 3] + if is_cell_state: + cell_filters.compute_cell_forces(model_output, state) + + # Build new (s, y) for history in per-system format [S, M_ext, 3] or [S, M, 3] + # s = position difference, y = gradient difference + if is_cell_state: + # Get new scaled forces and fractional positions for history + new_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + # new_forces: [N, 3] -> new_forces_scaled: [N, 3] + new_forces_scaled = torch.bmm( + new_forces.unsqueeze(1), # [N, 1, 3] + new_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) # [N, 3] + # positions: [N, 3] -> new_frac_positions: [N, 3] + new_frac_positions = torch.linalg.solve( + new_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # s_new_pos = frac_pos_new - frac_pos_old: [N, 3] -> [S, M, 3] + s_new_pos_atoms = new_frac_positions - state.prev_positions # [N, 3] + s_new_pos = _atoms_to_padded( + s_new_pos_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + # s_new_cell = cell_pos_new - cell_pos_old: [S, 3, 3] + s_new_cell = state.cell_positions - state.prev_cell_positions # [S, 3, 3] + # Concatenate to extended format: [S, M_ext, 3] + s_new = torch.cat([s_new_pos, s_new_cell], dim=1) # [S, M_ext, 3] + + # y_new = grad_diff for positions and cell (gradient = -forces) + # y = grad_new - grad_old = -forces_new - (-forces_old) = forces_old - forces_new + y_new_pos_atoms = -new_forces_scaled - (-state.prev_forces) # [N, 3] + y_new_pos = _atoms_to_padded( + y_new_pos_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + y_new_cell = -state.cell_forces - (-state.prev_cell_forces) # [S, 3, 3] + y_new = torch.cat([y_new_pos, y_new_cell], dim=1) # [S, M_ext, 3] + else: + # s_new = pos_new - pos_old: [N, 3] -> [S, M, 3] + s_new_atoms = state.positions - state.prev_positions # [N, 3] + s_new = _atoms_to_padded( + s_new_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + # y_new = grad_diff: [N, 3] -> [S, M, 3] + y_new_atoms = -new_forces - (-state.prev_forces) # [N, 3] + y_new = _atoms_to_padded( + y_new_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + + # Append history and trim if needed + # Note: ASE's L-BFGS doesn't have a curvature check for adding to history. + # Invalid curvatures are handled in the two-loop by checking rho. 
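+    # E.g. for S=2 systems padded to M=8 atoms with no cell filter: s_new and
+    # y_new are each [2, 8, 3] and join the history as one [2, 1, 8, 3] entry,
+    # so H grows by one per step until trimmed to max_history.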
+ # History tensors: [S, H, M_ext, 3] or [S, H, M, 3] + cur_history_len = state.s_history.shape[1] # H + if cur_history_len == 0: + # First entry: [S, 1, M_ext, 3] or [S, 1, M, 3] + s_hist = s_new.unsqueeze(1) # [S, 1, M_ext, 3] + y_hist = y_new.unsqueeze(1) # [S, 1, M_ext, 3] + else: + # Append new entry: [S, H, ...] cat [S, 1, ...] -> [S, H+1, ...] + s_hist = torch.cat([state.s_history, s_new.unsqueeze(1)], dim=1) + y_hist = torch.cat([state.y_history, y_new.unsqueeze(1)], dim=1) + # Trim to max_history + if s_hist.shape[1] > max_history: + s_hist = s_hist[:, -max_history:] # [S, max_history, ...] + y_hist = y_hist[:, -max_history:] + + # Update state + state.set_constrained_forces(new_forces) # [N, 3] + state.energy = new_energy # [S] + state.stress = new_stress # [S, 3, 3] or None + + if is_cell_state: + # Store fractional/scaled for next iteration + state.prev_positions = new_frac_positions.clone() # [N, 3] (fractional) + state.prev_forces = new_forces_scaled.clone() # [N, 3] (scaled) + else: + state.prev_forces = new_forces.clone() # [N, 3] + state.prev_positions = state.positions.clone() # [N, 3] + + state.s_history = s_hist # [S, H, M_ext, 3] or [S, H, M, 3] + state.y_history = y_hist # [S, H, M_ext, 3] or [S, H, M, 3] + state.n_iter = state.n_iter + 1 # [S] + + return state diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 358ea6b26..fb09795d0 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -1,6 +1,7 @@ """Optimizer state classes.""" from dataclasses import dataclass +from typing import ClassVar import torch @@ -50,4 +51,102 @@ class FireState(OptimState): _system_attributes = OptimState._system_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 +@dataclass(kw_only=True) +class BFGSState(OptimState): + """State for batched BFGS optimization. + + Stores the state needed to run a batched BFGS optimizer that maintains + an approximate Hessian matrix. + + Attributes: + hessian: Hessian matrix [n_systems, dim, dim] where dim = 3*max_atoms + for position-only or 3*max_atoms + 9 with cell filter. + May be padded when systems have different sizes. + prev_forces: Previous-step forces [n_atoms, 3]. For cell filter, + these are scaled forces (forces @ deform_grad) for ASE compatibility. + prev_positions: Previous-step positions [n_atoms, 3]. For cell filter, + these are fractional coordinates for ASE compatibility. + alpha: Initial Hessian scale (stiffness) [n_systems] + max_step: Maximum step size per atom [n_systems] + n_iter: Per-system iteration counter [n_systems] (int32) + atom_idx_in_system: Index of each atom within its system [n_atoms] + max_atoms: Atoms per system [n_systems] - used for size-binned eigendecomp + """ + + hessian: torch.Tensor + prev_forces: torch.Tensor + prev_positions: torch.Tensor + alpha: torch.Tensor + max_step: torch.Tensor + n_iter: torch.Tensor + atom_idx_in_system: torch.Tensor + max_atoms: torch.Tensor # Changed from int to Tensor for padding support + + _atom_attributes = OptimState._atom_attributes | { # noqa: SLF001 + "prev_forces", + "prev_positions", + "atom_idx_in_system", + } + _system_attributes = OptimState._system_attributes | { # noqa: SLF001 + "hessian", + "alpha", + "max_step", + "n_iter", + "max_atoms", + } + # Attributes that need padding when concatenating different-sized systems + _padded_system_attributes: ClassVar[set[str]] = {"hessian"} + + +@dataclass(kw_only=True) +class LBFGSState(OptimState): + """State for batched L-BFGS minimization (no line search). 
+
+    Stores the state needed to run a batched Limited-memory BFGS optimizer that
+    uses a fixed step size and the classical two-loop recursion to compute
+    approximate inverse-Hessian-vector products. All tensors are batched across
+    systems via `system_idx`.
+
+    Attributes:
+        prev_forces: Previous-step forces [n_atoms, 3]. For cell filter,
+            these are scaled forces (forces @ deform_grad) for ASE compatibility.
+        prev_positions: Previous-step positions [n_atoms, 3]. For cell filter,
+            these are fractional coordinates for ASE compatibility.
+        s_history: Displacement history [n_systems, h, max_atoms, 3] per-system.
+            For cell filter: [n_systems, h, max_atoms + 3, 3] to include cell DOFs.
+            May be padded when systems have different sizes.
+        y_history: Gradient-diff history [n_systems, h, max_atoms, 3] per-system.
+            For cell filter: [n_systems, h, max_atoms + 3, 3] to include cell DOFs.
+            May be padded when systems have different sizes.
+        step_size: Per-system fixed step size [n_systems]
+        alpha: Initial inverse Hessian scale (stiffness) [n_systems]
+        n_iter: Per-system iteration counter [n_systems] (int32)
+        max_atoms: Atoms per system [n_systems] - used for size-binned operations
+    """
+
+    prev_forces: torch.Tensor
+    prev_positions: torch.Tensor
+    s_history: torch.Tensor
+    y_history: torch.Tensor
+    step_size: torch.Tensor
+    alpha: torch.Tensor
+    n_iter: torch.Tensor
+    max_atoms: torch.Tensor  # [S] atoms per system for padding support
+
+    _atom_attributes = OptimState._atom_attributes | {  # noqa: SLF001
+        "prev_forces",
+        "prev_positions",
+    }
+    _system_attributes = OptimState._system_attributes | {  # noqa: SLF001
+        "step_size",
+        "alpha",
+        "n_iter",
+        "max_atoms",
+        "s_history",
+        "y_history",
+    }
+    # Attributes that need padding when concatenating different-sized systems
+    _padded_system_attributes: ClassVar[set[str]] = {"s_history", "y_history"}
+
+
 # there's no GradientDescentState, it's the same as OptimState
diff --git a/torch_sim/state.py b/torch_sim/state.py
index 318454c79..f465938c4 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -848,6 +848,12 @@ def _split_state[T: SimState](state: T) -> list[T]:
     zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64)
     cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0)))
     for sys_idx in range(n_systems):
+        # Build per-system attributes (padded attributes stay padded for consistency)
+        per_system_dict = {
+            attr_name: split_per_system[attr_name][sys_idx]
+            for attr_name in split_per_system
+        }
+
         system_attrs = {
             # Create a system tensor with all zeros for this system
             "system_idx": torch.zeros(
@@ -858,11 +864,8 @@
                 attr_name: split_per_atom[attr_name][sys_idx]
                 for attr_name in split_per_atom
             },
-            # Add the split per-system attributes
-            **{
-                attr_name: split_per_system[attr_name][sys_idx]
-                for attr_name in split_per_system
-            },
+            # Add the split per-system attributes (padded attributes stay padded)
+            **per_system_dict,
             # Add the global attributes
             **global_attrs,
         }
@@ -964,7 +967,7 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor
     return type(state)(**filtered_attrs)  # type: ignore[invalid-return-type]


-def concatenate_states[T: SimState](  # noqa: C901
+def concatenate_states[T: SimState](  # noqa: C901, PLR0915
     states: Sequence[T], device: torch.device | None = None
 ) -> T:
     """Concatenate a list of SimStates into a single SimState.
@@ -1039,10 +1042,57 @@ def concatenate_states[T: SimState]( # noqa: C901 # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) + # Get padded attributes if defined on the state class + padded_attrs = getattr(first_state, "_padded_system_attributes", set()) + for prop, tensors in per_system_tensors.items(): # if tensors: if isinstance(tensors[0], torch.Tensor): - concatenated[prop] = torch.cat(tensors, dim=0) + # TODO(AG): Is there a clean way to handle this? + if prop in padded_attrs: + # Pad tensors to max size before concatenating + # Detect tensor shape to determine padding strategy + first_tensor = tensors[0] + ndim = first_tensor.ndim + + if ndim == 3: + # Shape [S, D, D] required for BFGS hessian + # Pad last two dimensions + max_size = max(t.shape[-1] for t in tensors) + padded_tensors = [] + for t in tensors: + if t.shape[-1] < max_size: + pad_size = max_size - t.shape[-1] + t = torch.nn.functional.pad(t, (0, pad_size, 0, pad_size)) + padded_tensors.append(t) + concatenated[prop] = torch.cat(padded_tensors, dim=0) + elif ndim == 4: + # Shape [S, H, M, 3] required for L-BFGS history + # Pad dimension 2 (M) to max, and dimension 1 (H) to max + max_m = max(t.shape[2] for t in tensors) # max atoms dim + max_h = max(t.shape[1] for t in tensors) # max history dim + padded_tensors = [] + for t in tensors: + s_dim, h_dim, m_dim, last_dim = t.shape + if h_dim == 0: + # Special case: empty history, just create new shape + t = torch.zeros( + (s_dim, max_h, max_m, last_dim), + device=t.device, + dtype=t.dtype, + ) + elif m_dim < max_m or h_dim < max_h: + pad_m = max_m - m_dim + pad_h = max_h - h_dim + # For [S, H, M, 3]: pad M (dim 2) and H (dim 1) + t = torch.nn.functional.pad(t, (0, 0, 0, pad_m, 0, pad_h)) + padded_tensors.append(t) + concatenated[prop] = torch.cat(padded_tensors, dim=0) + else: + # Unknown shape, just concatenate without padding + concatenated[prop] = torch.cat(tensors, dim=0) + else: + concatenated[prop] = torch.cat(tensors, dim=0) else: # Non-tensor attributes, take first one (they should all be identical) concatenated[prop] = tensors[0]
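
For reference, a minimal single-system sketch of the two-loop recursion that lbfgs_step vectorizes across systems. The function name two_loop_direction and the plain-list histories s_hist/y_hist are illustrative only, not part of this diff; gamma follows the dynamic mode (alpha=None) described above, and the tiny-denominator guards are simplified.

import torch

def two_loop_direction(
    grad: torch.Tensor,
    s_hist: list[torch.Tensor],
    y_hist: list[torch.Tensor],
    curvature_eps: float = 1e-12,
) -> torch.Tensor:
    """Return the quasi-Newton direction -H_k @ grad for one system.

    grad is the flattened gradient (= -forces.reshape(-1)); s_hist/y_hist
    hold flat displacement and gradient-difference vectors, oldest first.
    """
    q = grad.clone()
    alphas: list[torch.Tensor] = []
    for s, y in zip(reversed(s_hist), reversed(y_hist)):  # newest -> oldest
        ys = torch.dot(y, s)
        # Degenerate-curvature pairs contribute nothing (rho = 0)
        rho = 1.0 / ys if ys.abs() > curvature_eps else torch.zeros(())
        a = rho * torch.dot(s, q)
        alphas.append(a)
        q = q - a * y
    if s_hist:
        # Dynamic H0 = gamma * I with gamma = (s^T y) / (y^T y)
        gamma = torch.dot(s_hist[-1], y_hist[-1]) / torch.dot(y_hist[-1], y_hist[-1])
    else:
        gamma = torch.ones(())
    z = gamma * q
    for (s, y), a in zip(zip(s_hist, y_hist), reversed(alphas)):  # oldest -> newest
        ys = torch.dot(y, s)
        rho = 1.0 / ys if ys.abs() > curvature_eps else torch.zeros(())
        z = z + (a - rho * torch.dot(y, z)) * s
    return -z

# Usage sketch: step = step_size * two_loop_direction(-forces.reshape(-1), s_list, y_list)

The batched implementation above replaces these per-pair Python dot products with masked per-system reductions (_per_system_vdot) over padded [S, H, M, 3] history tensors, so all S systems advance in lockstep on the GPU.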