From 4a4d43bd69ce7e8888593ab637b4b72f231d57b7 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 26 Nov 2025 08:58:32 -0800 Subject: [PATCH 01/17] Implement lbfgs --- .../2.8_MACE_LBFGS.py | 79 +++++ torch_sim/__init__.py | 3 + torch_sim/optimizers/__init__.py | 5 +- torch_sim/optimizers/lbfgs.py | 286 ++++++++++++++++++ torch_sim/optimizers/state.py | 39 +++ 5 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py create mode 100644 torch_sim/optimizers/lbfgs.py diff --git a/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py b/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py new file mode 100644 index 000000000..74c983d0f --- /dev/null +++ b/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py @@ -0,0 +1,79 @@ +"""Batched MACE L-BFGS optimizer with ASE comparison.""" + +# /// script +# dependencies = ["mace-torch>=0.3.12"] +# /// +import os + +import numpy as np +import torch +from ase.build import bulk +from ase.optimize import LBFGS as ASE_LBFGS +from mace.calculators.foundations_models import mace_mp + +import torch_sim as ts +from torch_sim.models.mace import MaceModel, MaceUrls + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + +loaded_model = mace_mp( + model=MaceUrls.mace_mpa_medium, + return_raw_model=True, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) + +SMOKE_TEST = os.getenv("CI") is not None +N_steps = 10 if SMOKE_TEST else 200 + +rng = np.random.default_rng(seed=0) + +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +model = MaceModel( + model=loaded_model, + device=device, + compute_forces=True, + compute_stress=True, + dtype=dtype, + enable_cueq=False, +) + +# torch-sim batched L-BFGS +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +initial_results = model(state) +state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0) + +for _ in range(N_steps): + state = ts.lbfgs_step(state=state, model=model, max_history=100) + +ts_final = [e.item() for e in state.energy] + +# ASE L-BFGS comparison +ase_calc = mace_mp( + model=MaceUrls.mace_mpa_medium, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) +ase_final = [] +for atoms in atoms_list: + atoms.calc = ase_calc + optimizer = ASE_LBFGS(atoms, logfile=None) + optimizer.run(fmax=0.01, steps=N_steps) + ase_final.append(atoms.get_potential_energy()) + +# Results +print(f"Initial energies: {[f'{e.item():.4f}' for e in initial_results['energy']]}") +print(f"torch-sim final: {[f'{e:.4f}' for e in ts_final]}") +print(f"ASE final: {[f'{e:.4f}' for e in ase_final]}") diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index 3bc1bc922..cc239b27a 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -48,12 +48,15 @@ from torch_sim.optimizers import ( OPTIM_REGISTRY, FireState, + LBFGSState, Optimizer, OptimState, fire_init, fire_step, gradient_descent_init, gradient_descent_step, + lbfgs_init, + lbfgs_step, ) from torch_sim.optimizers.cell_filters import ( CELL_FILTER_REGISTRY, diff --git a/torch_sim/optimizers/__init__.py b/torch_sim/optimizers/__init__.py index 850cfcac5..39085d0e7 100644 --- a/torch_sim/optimizers/__init__.py +++ b/torch_sim/optimizers/__init__.py @@ -16,7 +16,8 @@ gradient_descent_init, gradient_descent_step, ) -from torch_sim.optimizers.state import FireState, OptimState # noqa: F401 +from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step +from torch_sim.optimizers.state import FireState, LBFGSState, OptimState # noqa: F401 FireFlavor = Literal["vv_fire", "ase_fire"] @@ -28,9 +29,11 @@ class Optimizer(StrEnum): gradient_descent = "gradient_descent" fire = "fire" + lbfgs = "lbfgs" OPTIM_REGISTRY: Final[dict[Optimizer, tuple[Callable[..., Any], Callable[..., Any]]]] = { Optimizer.gradient_descent: (gradient_descent_init, gradient_descent_step), Optimizer.fire: (fire_init, fire_step), + Optimizer.lbfgs: (lbfgs_init, lbfgs_step), } diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py new file mode 100644 index 000000000..fb617ca64 --- /dev/null +++ b/torch_sim/optimizers/lbfgs.py @@ -0,0 +1,286 @@ +"""L-BFGS (Limited-memory BFGS) optimizer implementation. + +This module provides a batched L-BFGS optimizer for atomic structure relaxation. +L-BFGS is a quasi-Newton method that approximates the inverse Hessian using +a limited history of position and gradient differences, making it memory-efficient +for large systems while achieving superlinear convergence near the minimum. +""" + +from typing import TYPE_CHECKING + +import torch + +import torch_sim.math as tsm +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import LBFGSState + + +def lbfgs_init( + state: SimState | StateDict, + model: "ModelInterface", + *, + step_size: float = 0.1, + alpha: float | None = None, +) -> "LBFGSState": + r"""Create an initial LBFGSState from a SimState or state dict. + + Initializes forces/energy, clears the (s, y) memory, and broadcasts the + fixed step size to all systems. + + Args: + state: Input state as SimState object or state parameter dict + model: Model that computes energies, forces, and optionally stress + step_size: Fixed per-system step length (damping factor). + If using ASE mode (fixed alpha), set this to 1.0 (or your damping). + If using dynamic mode (default), 0.1 is a safe starting point. + alpha: Initial inverse Hessian stiffness guess (ASE parameter). + If provided (e.g. 70.0), fixes H0 = 1/alpha for all steps (ASE-style). + If None (default), H0 is updated dynamically (Standard L-BFGS). + + Returns: + LBFGSState with initialized optimization tensors + + Notes: + The optimizer supports two modes of operation: + 1. **Standard L-BFGS (default)**: Set `alpha=None`. The inverse Hessian + diagonal $H_0$ is updated dynamically at each step using the scaling + $\gamma_k = (s^T y) / (y^T y)$. This is the standard behavior described + by Nocedal & Wright. + 2. **ASE Compatibility Mode**: Set `alpha` (e.g. 70.0) and `step_size=1.0`. + The inverse Hessian diagonal is fixed at $H_0 = 1/\alpha$ throughout the + optimization, and the step is scaled by `step_size` (damping). + This matches `ase.optimize.LBFGS(alpha=70.0, damping=1.0)`. + """ + from torch_sim.optimizers import LBFGSState + + tensor_args = {"device": model.device, "dtype": model.dtype} + + if not isinstance(state, SimState): + state = SimState(**state) + + n_systems = state.n_systems + + # Get initial forces and energy from model + model_output = model(state) + energy = model_output["energy"] + forces = model_output["forces"] + stress = model_output["stress"] + + # Initialize empty history tensors + # History shape: [max_history, n_atoms, 3] but we start with 0 entries + s_history = torch.zeros((0, state.n_atoms, 3), **tensor_args) + y_history = torch.zeros((0, state.n_atoms, 3), **tensor_args) + + # Alpha tensor: 0.0 means dynamic, >0 means fixed + alpha_val = 0.0 if alpha is None else alpha + alpha_tensor = torch.full((n_systems,), alpha_val, **tensor_args) + + return LBFGSState( + # Copy SimState attributes + positions=state.positions.clone(), + masses=state.masses.clone(), + cell=state.cell.clone(), + atomic_numbers=state.atomic_numbers.clone(), + system_idx=state.system_idx.clone(), + pbc=state.pbc, + # Optimization state + forces=forces, + energy=energy, + stress=stress, + # L-BFGS specific state + prev_forces=forces.clone(), + prev_positions=state.positions.clone(), + s_history=s_history, + y_history=y_history, + step_size=torch.full((n_systems,), step_size, **tensor_args), + alpha=alpha_tensor, + n_iter=torch.zeros((n_systems,), device=model.device, dtype=torch.int32), + ) + + +def lbfgs_step( # noqa: PLR0915 + state: "LBFGSState", + model: "ModelInterface", + *, + max_history: int = 10, + max_step: float = 0.2, + curvature_eps: float = 1e-12, +) -> "LBFGSState": + r"""Advance one L-BFGS iteration using the two-loop recursion. + + Computes the search direction via the two-loop recursion, applies a + fixed step with optional per-system capping, evaluates new forces and + energy, and updates the limited-memory history with a curvature check. + + Algorithm (per system s): + 1) Evaluate gradient g_k = ∇E(x_k) = -f(x_k) + 2) Perform L-BFGS two-loop recursion using up to `max_history` pairs + (s_i, y_i) to compute d_k = -H_k g_k + 3) Fixed step update with optional per-system step capping by `max_step` + 4) Curvature check and history update: accept (s_k, y_k) if ⟨y_k, s_k⟩ > ε + + Args: + state: Current L-BFGS optimization state + model: Model that computes energies, forces, and optionally stress + max_history: Number of (s, y) pairs retained for the two-loop recursion. + max_step: If set, caps the maximum per-atom displacement per iteration. + curvature_eps: Threshold for the curvature ⟨y, s⟩ used to accept new + history pairs. + + Returns: + Updated LBFGSState after one optimization step + + Notes: + - If `state.alpha > 0` (ASE mode), the initial inverse Hessian estimate is + fixed at $H_0 = 1/\alpha$. + - Otherwise (Standard mode), $H_0$ varies at each step based on the + curvature of the most recent history pair. + + References: + - Nocedal & Wright, Numerical Optimization (L-BFGS two-loop recursion). + """ + device, dtype = model.device, model.dtype + eps = 1e-8 if dtype == torch.float32 else 1e-16 + + # Current gradient + g = -state.forces + + # Two-loop recursion to compute search direction d = -H_k g_k + q = g.clone() + alphas: list[torch.Tensor] = [] # per-history, shape [n_systems] + + # First loop (from newest to oldest) + for i in range(state.s_history.shape[0] - 1, -1, -1): + s_i = state.s_history[i] + y_i = state.y_history[i] + + ys = tsm.batched_vdot(y_i, s_i, state.system_idx) # y^T s per system + rho = torch.where( + ys.abs() > curvature_eps, + 1.0 / (ys + eps), + torch.zeros_like(ys), + ) + sq = tsm.batched_vdot(s_i, q, state.system_idx) + alpha = rho * sq + alphas.append(alpha) + + # q <- q - alpha * y_i (broadcast per system to atoms) + alpha_atom = alpha[state.system_idx].unsqueeze(-1) + q = q - alpha_atom * y_i + + # Initial H0 scaling: gamma = (s^T y)/(y^T y) using the last pair + # Dynamic gamma (Standard L-BFGS) + if state.s_history.shape[0] > 0: + s_last = state.s_history[-1] + y_last = state.y_history[-1] + sy = tsm.batched_vdot(s_last, y_last, state.system_idx) + yy = tsm.batched_vdot(y_last, y_last, state.system_idx) + gamma_dynamic = torch.where( + yy.abs() > curvature_eps, + sy / (yy + eps), + torch.ones_like(yy), + ) + else: + gamma_dynamic = torch.ones((state.n_systems,), device=device, dtype=dtype) + + # Fixed gamma (ASE style: 1/alpha) + # If state.alpha > 0, use that. Else use dynamic. + is_fixed = state.alpha > 1e-6 + gamma_fixed = 1.0 / (state.alpha + eps) + gamma = torch.where(is_fixed, gamma_fixed, gamma_dynamic) + + z = gamma[state.system_idx].unsqueeze(-1) * q + + # Second loop (from oldest to newest) + for i in range(state.s_history.shape[0]): + s_i = state.s_history[i] + y_i = state.y_history[i] + + ys = tsm.batched_vdot(y_i, s_i, state.system_idx) + rho = torch.where( + ys.abs() > curvature_eps, + 1.0 / (ys + eps), + torch.zeros_like(ys), + ) + yz = tsm.batched_vdot(y_i, z, state.system_idx) + beta = rho * yz + + alpha = alphas[state.s_history.shape[0] - 1 - i] + # z <- z + s_i * (alpha - beta) + coeff = (alpha - beta)[state.system_idx].unsqueeze(-1) + z = z + coeff * s_i + + d = -z # search direction + + # Optional per-system max step cap + # Compute per-atom step with current step_size + t_atoms = state.step_size[state.system_idx].unsqueeze(-1) + step = t_atoms * d + + # Per-atom norms + norms = torch.linalg.norm(step, dim=1) + + # Per-system max norm + sys_max = torch.zeros(state.n_systems, device=device, dtype=dtype) + sys_max.scatter_reduce_(0, state.system_idx, norms, reduce="amax", include_self=False) + + # Scaling factors per system: <= 1.0 + scale = torch.where( + sys_max > max_step, + max_step / (sys_max + eps), + torch.ones_like(sys_max), + ) + scale_atoms = scale[state.system_idx].unsqueeze(-1) + step = scale_atoms * step + + # Update positions + new_positions = state.positions + step + + # Evaluate new forces/energy + state.positions = new_positions + model_output = model(state) + new_forces = model_output["forces"] + new_energy = model_output["energy"] + new_stress = model_output("stress") + + # Build new (s, y) + s_new = state.positions - state.prev_positions + y_new = -new_forces - (-state.prev_forces) # g_new - g_prev = -(f_new - f_prev) + + # Curvature check per system; if bad, clear history (conservative) + sy = tsm.batched_vdot(s_new, y_new, state.system_idx) + bad_curv = sy <= curvature_eps + + if bad_curv.any(): + # Clear entire history to preserve correctness + s_hist = torch.zeros((0, state.n_atoms, 3), device=device, dtype=dtype) + y_hist = torch.zeros((0, state.n_atoms, 3), device=device, dtype=dtype) + else: + # Append and trim if needed + if state.s_history.shape[0] == 0: + s_hist = s_new.unsqueeze(0) + y_hist = y_new.unsqueeze(0) + else: + s_hist = torch.cat([state.s_history, s_new.unsqueeze(0)], dim=0) + y_hist = torch.cat([state.y_history, y_new.unsqueeze(0)], dim=0) + if s_hist.shape[0] > max_history: + s_hist = s_hist[-max_history:] + y_hist = y_hist[-max_history:] + + # Update state + state.forces = new_forces + state.energy = new_energy + state.stress = new_stress + + state.prev_forces = new_forces.clone() + state.prev_positions = state.positions.clone() + state.s_history = s_hist + state.y_history = y_hist + state.n_iter = state.n_iter + 1 + + return state diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 2ab530db4..5946a173f 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -40,4 +40,43 @@ class FireState(OptimState): _system_attributes = OptimState._system_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 +@dataclass(kw_only=True) +class LBFGSState(OptimState): + """State for batched L-BFGS minimization (no line search). + + Stores the state needed to run a batched Limited-memory BFGS optimizer that + uses a fixed step size and the classical two-loop recursion to compute + approximate inverse-Hessian-vector products. All tensors are batched across + systems via `system_idx`. + + Attributes: + prev_forces: Previous-step forces [n_atoms, 3] + prev_positions: Previous-step positions [n_atoms, 3] + s_history: Displacement history [h, n_atoms, 3] + y_history: Gradient-diff history [h, n_atoms, 3] + step_size: Per-system fixed step size [n_systems] + n_iter: Per-system iteration counter [n_systems] (int32) + """ + + prev_forces: torch.Tensor + prev_positions: torch.Tensor + s_history: torch.Tensor + y_history: torch.Tensor + step_size: torch.Tensor + alpha: torch.Tensor + n_iter: torch.Tensor + + _atom_attributes = OptimState._atom_attributes | { # noqa: SLF001 + "prev_forces", + "prev_positions", + } + _system_attributes = OptimState._system_attributes | { # noqa: SLF001 + "s_history", + "y_history", + "step_size", + "alpha", + "n_iter", + } + + # there's no GradientDescentState, it's the same as OptimState From 793de836731bef6f47f29c4f4a220efc8303d80a Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 26 Nov 2025 09:45:12 -0800 Subject: [PATCH 02/17] fix example --- torch_sim/optimizers/lbfgs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index fb617ca64..6c6408d31 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -246,7 +246,7 @@ def lbfgs_step( # noqa: PLR0915 model_output = model(state) new_forces = model_output["forces"] new_energy = model_output["energy"] - new_stress = model_output("stress") + new_stress = model_output["stress"] # Build new (s, y) s_new = state.positions - state.prev_positions From b4c692f659c512e99bcb41780d806f0a999e2783 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 26 Nov 2025 11:02:05 -0800 Subject: [PATCH 03/17] Implement bfgs --- .../2.9_MACE_BFGS.py | 79 +++++ torch_sim/__init__.py | 3 + torch_sim/optimizers/__init__.py | 10 +- torch_sim/optimizers/bfgs.py | 282 ++++++++++++++++++ torch_sim/optimizers/state.py | 41 +++ 5 files changed, 414 insertions(+), 1 deletion(-) create mode 100644 examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py create mode 100644 torch_sim/optimizers/bfgs.py diff --git a/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py b/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py new file mode 100644 index 000000000..138ced4bf --- /dev/null +++ b/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py @@ -0,0 +1,79 @@ +"""Batched MACE BFGS optimizer with ASE comparison.""" + +# /// script +# dependencies = ["mace-torch>=0.3.12"] +# /// +import os + +import numpy as np +import torch +from ase.build import bulk +from ase.optimize import BFGS as ASE_BFGS +from mace.calculators.foundations_models import mace_mp + +import torch_sim as ts +from torch_sim.models.mace import MaceModel, MaceUrls + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + +loaded_model = mace_mp( + model=MaceUrls.mace_mpa_medium, + return_raw_model=True, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) + +SMOKE_TEST = os.getenv("CI") is not None +N_steps = 10 if SMOKE_TEST else 200 + +rng = np.random.default_rng(seed=0) + +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +model = MaceModel( + model=loaded_model, + device=device, + compute_forces=True, + compute_stress=True, + dtype=dtype, + enable_cueq=False, +) + +# torch-sim batched BFGS +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +initial_results = model(state) +state = ts.bfgs_init(state=state, model=model, alpha=70.0) + +for _ in range(N_steps): + state = ts.bfgs_step(state=state, model=model) + +ts_final = [e.item() for e in state.energy] + +# ASE BFGS comparison +ase_calc = mace_mp( + model=MaceUrls.mace_mpa_medium, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) +ase_final = [] +for atoms in atoms_list: + atoms.calc = ase_calc + optimizer = ASE_BFGS(atoms, logfile=None, alpha=70.0) + optimizer.run(fmax=0.01, steps=N_steps) + ase_final.append(atoms.get_potential_energy()) + +# Results +print(f"Initial energies: {[f'{e.item():.4f}' for e in initial_results['energy']]}") +print(f"torch-sim final: {[f'{e:.4f}' for e in ts_final]}") +print(f"ASE final: {[f'{e:.4f}' for e in ase_final]}") diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index cc239b27a..b106ea2e2 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -47,10 +47,13 @@ from torch_sim.monte_carlo import SwapMCState, swap_mc_init, swap_mc_step from torch_sim.optimizers import ( OPTIM_REGISTRY, + BFGSState, FireState, LBFGSState, Optimizer, OptimState, + bfgs_init, + bfgs_step, fire_init, fire_step, gradient_descent_init, diff --git a/torch_sim/optimizers/__init__.py b/torch_sim/optimizers/__init__.py index 39085d0e7..3a8d4fafe 100644 --- a/torch_sim/optimizers/__init__.py +++ b/torch_sim/optimizers/__init__.py @@ -10,6 +10,7 @@ from enum import StrEnum from typing import Any, Final, Literal, get_args +from torch_sim.optimizers.bfgs import bfgs_init, bfgs_step from torch_sim.optimizers.cell_filters import CellFireState, CellOptimState # noqa: F401 from torch_sim.optimizers.fire import fire_init, fire_step from torch_sim.optimizers.gradient_descent import ( @@ -17,7 +18,12 @@ gradient_descent_step, ) from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step -from torch_sim.optimizers.state import FireState, LBFGSState, OptimState # noqa: F401 +from torch_sim.optimizers.state import ( # noqa: F401 + BFGSState, + FireState, + LBFGSState, + OptimState, +) FireFlavor = Literal["vv_fire", "ase_fire"] @@ -30,10 +36,12 @@ class Optimizer(StrEnum): gradient_descent = "gradient_descent" fire = "fire" lbfgs = "lbfgs" + bfgs = "bfgs" OPTIM_REGISTRY: Final[dict[Optimizer, tuple[Callable[..., Any], Callable[..., Any]]]] = { Optimizer.gradient_descent: (gradient_descent_init, gradient_descent_step), Optimizer.fire: (fire_init, fire_step), Optimizer.lbfgs: (lbfgs_init, lbfgs_step), + Optimizer.bfgs: (bfgs_init, bfgs_step), } diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py new file mode 100644 index 000000000..bb7021c5d --- /dev/null +++ b/torch_sim/optimizers/bfgs.py @@ -0,0 +1,282 @@ +"""BFGS (Broyden-Fletcher-Goldfarb-Shanno) optimizer implementation. + +This module provides a batched BFGS optimizer that maintains the full Hessian +matrix for each system. This is suitable for systems with a small to moderate +number of atoms, where the $O(N^2)$ memory cost is acceptable. + +The implementation handles batches of systems with different numbers of atoms +by padding vectors to the maximum number of atoms in the batch. The Hessian +matrices are similarly padded to shape (n_systems, 3*max_atoms, 3*max_atoms). +""" + +from typing import TYPE_CHECKING + +import torch + +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import BFGSState + + +def _get_atom_indices_per_system( + system_idx: torch.Tensor, n_systems: int +) -> torch.Tensor: + """Compute the index of each atom within its system. + + Assumes atoms are grouped contiguously by system. + + Args: + system_idx: Tensor of system indices [n_atoms] + n_systems: Number of systems + + Returns: + Tensor of [0, 1, 2, ..., 0, 1, ...] [n_atoms] + """ + # We assume contiguous atoms for each system, which is standard in SimState + counts = torch.bincount(system_idx, minlength=n_systems) + # Create ranges [0...n-1] for each system and concatenate + indices = [torch.arange(c, device=system_idx.device) for c in counts] + return torch.cat(indices) + + +def _pad_to_dense( + flat_tensor: torch.Tensor, + system_idx: torch.Tensor, + atom_idx_in_system: torch.Tensor, + n_systems: int, + max_atoms: int, +) -> torch.Tensor: + """Convert a packed tensor to a padded dense tensor. + + Args: + flat_tensor: [n_atoms, D] + system_idx: [n_atoms] + atom_idx_in_system: [n_atoms] + n_systems: int + max_atoms: int + + Returns: + dense_tensor: [n_systems, max_atoms, D] + """ + D = flat_tensor.shape[1] + dense = torch.zeros( + (n_systems, max_atoms, D), dtype=flat_tensor.dtype, device=flat_tensor.device + ) + dense[system_idx, atom_idx_in_system] = flat_tensor + return dense + + +def bfgs_init( + state: SimState | StateDict, + model: "ModelInterface", + *, + max_step: float = 0.2, + alpha: float = 70.0, +) -> "BFGSState": + """Create an initial BFGSState. + + Initializes the Hessian as Identity * alpha. + + Args: + state: Input state + model: Model + max_step: Maximum step size (Angstrom) + alpha: Initial Hessian stiffness (eV/A^2) + + Returns: + BFGSState + """ + from torch_sim.optimizers import BFGSState + + tensor_args = {"device": model.device, "dtype": model.dtype} + + if not isinstance(state, SimState): + state = SimState(**state) + + n_systems = state.n_systems + + counts = state.n_atoms_per_system + max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 + atom_idx = _get_atom_indices_per_system(state.system_idx, n_systems) + + model_output = model(state) + energy = model_output["energy"] + forces = model_output["forces"] + stress = model_output["stress"] + + # shape: (n_systems, 3*max_atoms, 3*max_atoms) + dim = 3 * max_atoms + hessian = torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha + + alpha_t = torch.full((n_systems,), alpha, **tensor_args) + max_step_t = torch.full((n_systems,), max_step, **tensor_args) + n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) + + return BFGSState( + positions=state.positions.clone(), + masses=state.masses.clone(), + cell=state.cell.clone(), + atomic_numbers=state.atomic_numbers.clone(), + forces=forces, + energy=energy, + stress=stress, + hessian=hessian, + prev_forces=forces.clone(), + prev_positions=state.positions.clone(), + alpha=alpha_t, + max_step=max_step_t, + n_iter=n_iter, + atom_idx_in_system=atom_idx, + max_atoms=max_atoms, + # passed to __post_init__ + system_idx=state.system_idx.clone(), + pbc=state.pbc, + ) + + +def bfgs_step( + state: "BFGSState", + model: "ModelInterface", +) -> "BFGSState": + """Perform one BFGS optimization step. + + Updates the Hessian estimate and moves atoms. + + Args: + state: Current optimization state + model: Calculator model + + Returns: + Updated state + """ + eps = 1e-7 + + # Pack flat tensors into dense batched tensors + # shape: (n_systems, max_atoms * 3) + pos_new = _pad_to_dense( + state.positions, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + forces_new = _pad_to_dense( + state.forces, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + pos_old = _pad_to_dense( + state.prev_positions, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + forces_old = _pad_to_dense( + state.prev_forces, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + # Calculate displacements and force changes + # dpos: (n_systems, max_atoms * 3) + dpos = pos_new - pos_old + dforces = -(forces_new - forces_old) + + # Identify systems with significant movement + max_disp = torch.max(torch.abs(dpos), dim=1).values + update_mask = max_disp >= eps + + # Update Hessian for active systems + if update_mask.any(): + idx = update_mask + H = state.hessian[idx] + + # shape: (n_active, D, 1) + dp = dpos[idx].unsqueeze(2) + df = dforces[idx].unsqueeze(2) # noqa: PD901 + + # shape: (n_active, 1) + a = torch.bmm(dp.transpose(1, 2), df).squeeze(2) + + # shape: (n_active, D, 1) + dg = torch.bmm(H, dp) + + # shape: (n_active, 1) + b = torch.bmm(dp.transpose(1, 2), dg).squeeze(2) + + # Rank-2 update + # shape: (n_active, D, D) + term1 = torch.bmm(df, df.transpose(1, 2)) / (a.unsqueeze(2) + 1e-30) + term2 = torch.bmm(dg, dg.transpose(1, 2)) / (b.unsqueeze(2) + 1e-30) + + state.hessian[idx] = H - term1 - term2 + + # Calculate step direction using eigendecomposition + # gradient: (n_systems, D, 1) + # Step p = H^-1 * F + direction = forces_new.unsqueeze(2) + + # omega: (n_systems, D), V: (n_systems, D, D) + omega, V = torch.linalg.eigh(state.hessian) + + # shape: (n_systems, 1, D) + abs_omega = torch.abs(omega).unsqueeze(1) + abs_omega = torch.where(abs_omega < 1e-30, torch.ones_like(abs_omega), abs_omega) + + # Project direction onto eigenvectors and scale + # shape: (n_systems, D, 1) + vt_g = torch.bmm(V.transpose(1, 2), direction) + scaled = vt_g / abs_omega.transpose(1, 2) + + # Transform back to original basis + # shape: (n_systems, D) + step_dense = torch.bmm(V, scaled).squeeze(2) + + # Scale step if it exceeds max_step + # step_atoms: (n_systems, max_atoms, 3) + step_atoms = step_dense.view(state.n_systems, state.max_atoms, 3) + # atom_norms: (n_systems, max_atoms) + atom_norms = torch.norm(step_atoms, dim=2) + + # max_disp_per_sys: (n_systems,) + max_disp_per_sys = torch.max(atom_norms, dim=1).values + + scale = torch.ones_like(max_disp_per_sys) + needs_scale = max_disp_per_sys > state.max_step + scale[needs_scale] = state.max_step[needs_scale] / ( + max_disp_per_sys[needs_scale] + 1e-30 + ) + + # shape: (n_systems, D) + step_dense = step_dense * scale.unsqueeze(1) + + # Unpack dense step back to flat valid atoms + flat_step = step_dense.view(state.n_systems, state.max_atoms, 3)[ + state.system_idx, state.atom_idx_in_system + ] + + new_positions = state.positions + flat_step + + state.prev_positions = state.positions.clone() + state.prev_forces = state.forces.clone() + state.positions = new_positions + + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + state.n_iter += 1 + + return state diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 5946a173f..92ace560c 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -40,6 +40,47 @@ class FireState(OptimState): _system_attributes = OptimState._system_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 +@dataclass(kw_only=True) +class BFGSState(OptimState): + """State for batched BFGS optimization. + + Stores the state needed to run a batched BFGS optimizer that maintains + an approximate Hessian or inverse Hessian. + + Attributes: + hessian: Hessian matrix [n_systems, 3*max_atoms, 3*max_atoms] + prev_forces: Previous-step forces [n_atoms, 3] + prev_positions: Previous-step positions [n_atoms, 3] + alpha: Initial Hessian scale [n_systems] + max_step: Maximum step size [n_systems] + n_iter: Per-system iteration counter [n_systems] (int32) + atom_idx_in_system: Index of each atom within its system [n_atoms] + max_atoms: Maximum number of atoms in any system (int) + """ + + hessian: torch.Tensor + prev_forces: torch.Tensor + prev_positions: torch.Tensor + alpha: torch.Tensor + max_step: torch.Tensor + n_iter: torch.Tensor + atom_idx_in_system: torch.Tensor + max_atoms: int + + _atom_attributes = OptimState._atom_attributes | { # noqa: SLF001 + "prev_forces", + "prev_positions", + "atom_idx_in_system", + } + _system_attributes = OptimState._system_attributes | { # noqa: SLF001 + "hessian", + "alpha", + "max_step", + "n_iter", + "max_atoms", + } + + @dataclass(kw_only=True) class LBFGSState(OptimState): """State for batched L-BFGS minimization (no line search). From b41a8ef7681da881f984104451492cf8bec73297 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 9 Jan 2026 10:12:18 -0500 Subject: [PATCH 04/17] clean scripts per #385 --- .../2.8_MACE_LBFGS.py | 79 ------------------- .../2.9_MACE_BFGS.py | 79 ------------------- examples/scripts/2_structural_optimization.py | 78 ++++++++++++++++-- torch_sim/optimizers/lbfgs.py | 2 +- 4 files changed, 73 insertions(+), 165 deletions(-) delete mode 100644 examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py delete mode 100644 examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py diff --git a/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py b/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py deleted file mode 100644 index 74c983d0f..000000000 --- a/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Batched MACE L-BFGS optimizer with ASE comparison.""" - -# /// script -# dependencies = ["mace-torch>=0.3.12"] -# /// -import os - -import numpy as np -import torch -from ase.build import bulk -from ase.optimize import LBFGS as ASE_LBFGS -from mace.calculators.foundations_models import mace_mp - -import torch_sim as ts -from torch_sim.models.mace import MaceModel, MaceUrls - - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -dtype = torch.float32 - -loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, - return_raw_model=True, - default_dtype=str(dtype).removeprefix("torch."), - device=str(device), -) - -SMOKE_TEST = os.getenv("CI") is not None -N_steps = 10 if SMOKE_TEST else 200 - -rng = np.random.default_rng(seed=0) - -si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) -si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) - -cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) -cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) - -fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) -fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) - -atoms_list = [si_dc, cu_dc, fe_dc] - -model = MaceModel( - model=loaded_model, - device=device, - compute_forces=True, - compute_stress=True, - dtype=dtype, - enable_cueq=False, -) - -# torch-sim batched L-BFGS -state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) -initial_results = model(state) -state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0) - -for _ in range(N_steps): - state = ts.lbfgs_step(state=state, model=model, max_history=100) - -ts_final = [e.item() for e in state.energy] - -# ASE L-BFGS comparison -ase_calc = mace_mp( - model=MaceUrls.mace_mpa_medium, - default_dtype=str(dtype).removeprefix("torch."), - device=str(device), -) -ase_final = [] -for atoms in atoms_list: - atoms.calc = ase_calc - optimizer = ASE_LBFGS(atoms, logfile=None) - optimizer.run(fmax=0.01, steps=N_steps) - ase_final.append(atoms.get_potential_energy()) - -# Results -print(f"Initial energies: {[f'{e.item():.4f}' for e in initial_results['energy']]}") -print(f"torch-sim final: {[f'{e:.4f}' for e in ts_final]}") -print(f"ASE final: {[f'{e:.4f}' for e in ase_final]}") diff --git a/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py b/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py deleted file mode 100644 index 138ced4bf..000000000 --- a/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Batched MACE BFGS optimizer with ASE comparison.""" - -# /// script -# dependencies = ["mace-torch>=0.3.12"] -# /// -import os - -import numpy as np -import torch -from ase.build import bulk -from ase.optimize import BFGS as ASE_BFGS -from mace.calculators.foundations_models import mace_mp - -import torch_sim as ts -from torch_sim.models.mace import MaceModel, MaceUrls - - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -dtype = torch.float32 - -loaded_model = mace_mp( - model=MaceUrls.mace_mpa_medium, - return_raw_model=True, - default_dtype=str(dtype).removeprefix("torch."), - device=str(device), -) - -SMOKE_TEST = os.getenv("CI") is not None -N_steps = 10 if SMOKE_TEST else 200 - -rng = np.random.default_rng(seed=0) - -si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) -si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) - -cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) -cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) - -fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) -fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) - -atoms_list = [si_dc, cu_dc, fe_dc] - -model = MaceModel( - model=loaded_model, - device=device, - compute_forces=True, - compute_stress=True, - dtype=dtype, - enable_cueq=False, -) - -# torch-sim batched BFGS -state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) -initial_results = model(state) -state = ts.bfgs_init(state=state, model=model, alpha=70.0) - -for _ in range(N_steps): - state = ts.bfgs_step(state=state, model=model) - -ts_final = [e.item() for e in state.energy] - -# ASE BFGS comparison -ase_calc = mace_mp( - model=MaceUrls.mace_mpa_medium, - default_dtype=str(dtype).removeprefix("torch."), - device=str(device), -) -ase_final = [] -for atoms in atoms_list: - atoms.calc = ase_calc - optimizer = ASE_BFGS(atoms, logfile=None, alpha=70.0) - optimizer.run(fmax=0.01, steps=N_steps) - ase_final.append(atoms.get_potential_energy()) - -# Results -print(f"Initial energies: {[f'{e.item():.4f}' for e in initial_results['energy']]}") -print(f"torch-sim final: {[f'{e:.4f}' for e in ts_final]}") -print(f"ASE final: {[f'{e:.4f}' for e in ase_final]}") diff --git a/examples/scripts/2_structural_optimization.py b/examples/scripts/2_structural_optimization.py index 423504994..3adf0b4f2 100644 --- a/examples/scripts/2_structural_optimization.py +++ b/examples/scripts/2_structural_optimization.py @@ -31,7 +31,7 @@ # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None -N_steps = 10 if SMOKE_TEST else 500 +N_steps = 10 if SMOKE_TEST else 100 # ============================================================================ @@ -111,7 +111,7 @@ # Run optimization for step in range(N_steps): - if step % 100 == 0: + if step % (N_steps // 5) == 0: print(f"Step {step}: Potential energy: {state.energy[0].item()} eV") state = ts.fire_step(state=state, model=lj_model, dt_max=0.01) @@ -174,7 +174,7 @@ print("\nRunning FIRE:") for step in range(N_steps): - if step % 20 == 0: + if step % (N_steps // 5) == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") state = ts.fire_step(state=state, model=model, dt_max=0.01) @@ -254,7 +254,7 @@ print("\nRunning batched unit cell gradient descent:") for step in range(N_steps): - if step % 20 == 0: + if step % (N_steps // 5) == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -308,7 +308,7 @@ print("\nRunning batched unit cell FIRE:") for step in range(N_steps): - if step % 20 == 0: + if step % (N_steps // 5) == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -360,7 +360,7 @@ print("\nRunning batched frechet cell filter with FIRE:") for step in range(N_steps): - if step % 20 == 0: + if step % (N_steps // 5) == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -386,6 +386,72 @@ print(f"Initial pressure: {initial_pressure} GPa") print(f"Final pressure: {final_pressure} GPa") +# ============================================================================ +# SECTION 7: Batched MACE L-BFGS +# ============================================================================ +print("\n" + "=" * 70) +print("SECTION 7: Batched MACE L-BFGS") +print("=" * 70) + +# Recreate structures with perturbations +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +results = model(state) +state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0) + +print("\nRunning L-BFGS:") +for step in range(N_steps): + if step % (N_steps // 5) == 0: + print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") + state = ts.lbfgs_step(state=state, model=model, max_history=100) + +print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") +print(f"Final energies: {[energy.item() for energy in state.energy]} eV") + + +# ============================================================================ +# SECTION 8: Batched MACE BFGS +# ============================================================================ +print("\n" + "=" * 70) +print("SECTION 8: Batched MACE BFGS") +print("=" * 70) + +# Recreate structures with perturbations +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +results = model(state) +state = ts.bfgs_init(state=state, model=model, alpha=70.0) + +print("\nRunning BFGS:") +for step in range(N_steps): + if step % (N_steps // 5) == 0: + print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") + state = ts.bfgs_step(state=state, model=model) + +print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV") +print(f"Final energies: {[energy.item() for energy in state.energy]} eV") + + print("\n" + "=" * 70) print("Structural optimization examples completed!") print("=" * 70) diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 6c6408d31..6644c94cb 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -213,7 +213,7 @@ def lbfgs_step( # noqa: PLR0915 alpha = alphas[state.s_history.shape[0] - 1 - i] # z <- z + s_i * (alpha - beta) coeff = (alpha - beta)[state.system_idx].unsqueeze(-1) - z = z + coeff * s_i + z = z + s_i * coeff d = -z # search direction From cea301d7facd5092e7b8f6c48bf7e77ce0a7355a Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Tue, 3 Feb 2026 21:39:01 -0800 Subject: [PATCH 05/17] correct unit cell filter and add CellLBFGS and CellBFGS states --- torch_sim/optimizers/cell_filters.py | 68 ++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 3ff0cf2de..3a061f23b 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -15,7 +15,7 @@ import torch_sim.math as fm from torch_sim.models.interface import ModelInterface -from torch_sim.optimizers.state import FireState, OptimState +from torch_sim.optimizers.state import BFGSState, FireState, LBFGSState, OptimState from torch_sim.state import SimState @@ -309,7 +309,13 @@ def compute_cell_forces[T: AnyCellState]( state.cell_forces = cell_forces / state.cell_factor else: # Unit cell force computation - state.cell_forces = virial / state.cell_factor + # Note (AG): ASE transforms virial as: + # virial = np.linalg.solve(cur_deform_grad, virial.T).T + cur_deform_grad = deform_grad(state.reference_cell.mT, state.row_vector_cell) + virial_transformed = torch.linalg.solve( + cur_deform_grad, virial.transpose(-2, -1) + ).transpose(-2, -1) + state.cell_forces = virial_transformed / state.cell_factor CellFilterFuncs = tuple[Callable[..., None], Callable[..., None]] # (init_fn, update_fn) @@ -378,4 +384,60 @@ class CellFireState(CellOptimState, FireState): ) -AnyCellState = CellFireState | CellOptimState +@dataclass(kw_only=True) +class CellBFGSState(CellOptimState, BFGSState): + """State class for BFGS optimization with cell optimization. + + Combines BFGS position optimization with cell filter for simultaneous + optimization of atomic positions and unit cell parameters using a unified + extended coordinate space (positions + cell DOFs). + """ + + # Previous cell state for Hessian update + prev_cell_positions: torch.Tensor = field(default_factory=lambda: None) + prev_cell_forces: torch.Tensor = field(default_factory=lambda: None) + + _atom_attributes = ( + CellOptimState._atom_attributes # noqa: SLF001 + | BFGSState._atom_attributes # noqa: SLF001 + ) + _system_attributes = ( + CellOptimState._system_attributes # noqa: SLF001 + | BFGSState._system_attributes # noqa: SLF001 + | {"prev_cell_positions", "prev_cell_forces"} + ) + _global_attributes = ( + CellOptimState._global_attributes # noqa: SLF001 + | BFGSState._global_attributes # noqa: SLF001 + ) + + +@dataclass(kw_only=True) +class CellLBFGSState(CellOptimState, LBFGSState): + """State class for L-BFGS optimization with cell optimization. + + Combines L-BFGS position optimization with cell filter for simultaneous + optimization of atomic positions and unit cell parameters using a unified + extended coordinate space (positions + cell DOFs). + """ + + # Previous cell state for history update + prev_cell_positions: torch.Tensor = field(default_factory=lambda: None) + prev_cell_forces: torch.Tensor = field(default_factory=lambda: None) + + _atom_attributes = ( + CellOptimState._atom_attributes # noqa: SLF001 + | LBFGSState._atom_attributes # noqa: SLF001 + ) + _system_attributes = ( + CellOptimState._system_attributes # noqa: SLF001 + | LBFGSState._system_attributes # noqa: SLF001 + | {"prev_cell_positions", "prev_cell_forces"} + ) + _global_attributes = ( + CellOptimState._global_attributes # noqa: SLF001 + | LBFGSState._global_attributes # noqa: SLF001 + ) + + +AnyCellState = CellFireState | CellOptimState | CellBFGSState | CellLBFGSState From 9478d8509e1da1319dde949364661f2e0741594a Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Tue, 3 Feb 2026 21:40:32 -0800 Subject: [PATCH 06/17] global attr for LBFGS history --- torch_sim/optimizers/state.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index ecf414356..28b265b5a 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -122,12 +122,16 @@ class LBFGSState(OptimState): "prev_positions", } _system_attributes = OptimState._system_attributes | { # noqa: SLF001 - "s_history", - "y_history", "step_size", "alpha", "n_iter", } + # Note (AG): s_history and y_history are global attributes because they are not + # per-system indexable, so they must be copied as-is on slice. + _global_attributes = OptimState._global_attributes | { # noqa: SLF001 + "s_history", + "y_history", + } # there's no GradientDescentState, it's the same as OptimState From 9c20599dcfe7c144f7f20bd6a7b5edf5ff48cd03 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Tue, 3 Feb 2026 21:40:46 -0800 Subject: [PATCH 07/17] Add methods to init --- torch_sim/optimizers/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch_sim/optimizers/__init__.py b/torch_sim/optimizers/__init__.py index 3a8d4fafe..7223e84cb 100644 --- a/torch_sim/optimizers/__init__.py +++ b/torch_sim/optimizers/__init__.py @@ -11,7 +11,12 @@ from typing import Any, Final, Literal, get_args from torch_sim.optimizers.bfgs import bfgs_init, bfgs_step -from torch_sim.optimizers.cell_filters import CellFireState, CellOptimState # noqa: F401 +from torch_sim.optimizers.cell_filters import ( # noqa: F401 + CellBFGSState, + CellFireState, + CellLBFGSState, + CellOptimState, +) from torch_sim.optimizers.fire import fire_init, fire_step from torch_sim.optimizers.gradient_descent import ( gradient_descent_init, From a9ede4756fca2d234b6c4ad28d9a6724fe7d5aff Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Tue, 3 Feb 2026 22:06:11 -0800 Subject: [PATCH 08/17] Add test comparing with ASE --- tests/test_optimizers_vs_ase.py | 337 ++++++++++++++++++++++++++++++++ 1 file changed, 337 insertions(+) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index 812375684..aa509138a 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -4,7 +4,9 @@ import pytest import torch from ase.filters import FrechetCellFilter, UnitCellFilter +from ase.optimize import BFGS as ASE_BFGS from ase.optimize import FIRE +from ase.optimize import LBFGS as ASE_LBFGS from pymatgen.analysis.structure_matcher import StructureMatcher import torch_sim as ts @@ -316,3 +318,338 @@ def test_optimizer_vs_ase_parametrized( tolerances=tolerances, test_id_prefix=test_id_prefix, ) + + +# TODO (AG): Can we merge these tests with the FIRE tests? + + +@pytest.mark.parametrize( + ( + "sim_state_fixture_name", + "cell_filter", + "ase_filter_cls", + "checkpoints", + "force_tol", + "tolerances", + "test_id_prefix", + ), + [ + ( + "rattled_sio2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS SiO2 (Frechet)", + ), + ( + "osn2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS OsN2 (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "BFGS Triclinic Al (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "BFGS Triclinic Al (UnitCell)", + ), + ( + "rattled_sio2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS SiO2 (UnitCell)", + ), + ( + "osn2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "BFGS OsN2 (UnitCell)", + ), + ], +) +def test_bfgs_vs_ase_parametrized( + sim_state_fixture_name: str, + cell_filter: ts.CellFilter, + ase_filter_cls: type, + checkpoints: list[int], + force_tol: float, + tolerances: dict[str, float], + test_id_prefix: str, + ts_mace_mpa: MaceModel, + ase_mace_mpa: "MACECalculator", + request: pytest.FixtureRequest, +) -> None: + """Compare torch-sim BFGS with ASE BFGS at multiple checkpoints.""" + pytest.importorskip("mace") + device = ts_mace_mpa.device + + initial_sim_state = request.getfixturevalue(sim_state_fixture_name) + state = initial_sim_state.clone() + + ase_atoms = ts.io.state_to_atoms( + initial_sim_state.clone().to(dtype=DTYPE, device=device) + )[0] + ase_atoms.calc = ase_mace_mpa + filtered_ase_atoms = ase_filter_cls(ase_atoms) + ase_optimizer = ASE_BFGS(filtered_ase_atoms, logfile=None, alpha=70.0) + + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=True + ) + + # Compare initial state + results = ts_mace_mpa(state) + ts_initial = state.clone() + ts_initial.forces = results["forces"] + ts_initial.energy = results["energy"] + ase_mace_mpa.calculate(ase_atoms) + _compare_ase_and_ts_states( + ts_initial, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Initial)" + ) + + last_step = 0 + for checkpoint in checkpoints: + steps = checkpoint - last_step + if steps > 0: + state = ts.optimize( + system=state, + model=ts_mace_mpa, + optimizer=ts.Optimizer.bfgs, + max_steps=steps, + convergence_fn=convergence_fn, + steps_between_swaps=1, + init_kwargs=dict(cell_filter=cell_filter), + ) + ase_optimizer.run(fmax=force_tol, steps=steps) + + _compare_ase_and_ts_states( + state, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Step {checkpoint})" + ) + last_step = checkpoint + + +# TODO (AG): Can we merge these tests with the FIRE tests? + + +@pytest.mark.parametrize( + ( + "sim_state_fixture_name", + "cell_filter", + "ase_filter_cls", + "checkpoints", + "force_tol", + "tolerances", + "test_id_prefix", + ), + [ + ( + "rattled_sio2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS SiO2 (Frechet)", + ), + ( + "osn2_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS OsN2 (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.frechet, + FrechetCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "LBFGS Triclinic Al (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.01, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, + "LBFGS Triclinic Al (UnitCell)", + ), + ( + "rattled_sio2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 33, 66, 100], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS SiO2 (UnitCell)", + ), + ( + "osn2_sim_state", + ts.CellFilter.unit, + UnitCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "LBFGS OsN2 (UnitCell)", + ), + ], +) +def test_lbfgs_vs_ase_parametrized( + sim_state_fixture_name: str, + cell_filter: ts.CellFilter, + ase_filter_cls: type, + checkpoints: list[int], + force_tol: float, + tolerances: dict[str, float], + test_id_prefix: str, + ts_mace_mpa: MaceModel, + ase_mace_mpa: "MACECalculator", + request: pytest.FixtureRequest, +) -> None: + """Compare torch-sim L-BFGS with ASE LBFGS at multiple checkpoints.""" + pytest.importorskip("mace") + device = ts_mace_mpa.device + + initial_sim_state = request.getfixturevalue(sim_state_fixture_name) + state = initial_sim_state.clone() + + ase_atoms = ts.io.state_to_atoms( + initial_sim_state.clone().to(dtype=DTYPE, device=device) + )[0] + ase_atoms.calc = ase_mace_mpa + filtered_ase_atoms = ase_filter_cls(ase_atoms) + ase_optimizer = ASE_LBFGS(filtered_ase_atoms, logfile=None, alpha=70.0, damping=1.0) + + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=True + ) + + # Compare initial state + results = ts_mace_mpa(state) + ts_initial = state.clone() + ts_initial.forces = results["forces"] + ts_initial.energy = results["energy"] + ase_mace_mpa.calculate(ase_atoms) + _compare_ase_and_ts_states( + ts_initial, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Initial)" + ) + + last_step = 0 + for checkpoint in checkpoints: + steps = checkpoint - last_step + if steps > 0: + state = ts.optimize( + system=state, + model=ts_mace_mpa, + optimizer=ts.Optimizer.lbfgs, + max_steps=steps, + convergence_fn=convergence_fn, + steps_between_swaps=1, + init_kwargs=dict(cell_filter=cell_filter, alpha=70.0, step_size=1.0), + max_step=0.2, + ) + ase_optimizer.run(fmax=force_tol, steps=steps) + + _compare_ase_and_ts_states( + state, filtered_ase_atoms, tolerances, f"{test_id_prefix} (Step {checkpoint})" + ) + last_step = checkpoint From 802c966d98441c7ca7e032a630186425029f52d7 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 01:03:54 -0800 Subject: [PATCH 09/17] Update batched bfgs --- tests/test_autobatching.py | 217 +++++++++++++ torch_sim/optimizers/bfgs.py | 556 +++++++++++++++++++++++++--------- torch_sim/optimizers/state.py | 34 ++- torch_sim/state.py | 31 +- 4 files changed, 674 insertions(+), 164 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 7ec870d3c..1baecc30d 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -605,3 +605,220 @@ def test_in_flight_max_iterations( # Verify iteration_count tracking for idx in range(len(states)): assert batcher.iteration_count[idx] == max_iterations + + +@pytest.mark.parametrize( + "num_steps_per_batch", + [ + 5, # At 5 steps, not every state will converge before the next batch. + 10, # At 10 steps, all states will converge before the next batch + ], +) +def test_in_flight_with_bfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + num_steps_per_batch: int, +) -> None: + """Test InFlightAutoBatcher with BFGS optimizer (matching FIRE test structure).""" + si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_bfgs_state = ts.bfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5 + bfgs_states = [state.clone() for state in bfgs_states] + for state in bfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = InFlightAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=6000, + ) + batcher.load_states(bfgs_states) + + def convergence_fn(state: ts.BFGSState) -> torch.Tensor: + system_wise_max_force = torch.zeros( + state.n_systems, device=state.device, dtype=torch.float64 + ) + max_forces = state.forces.norm(dim=1) + system_wise_max_force = system_wise_max_force.scatter_reduce( + dim=0, index=state.system_idx, src=max_forces, reduce="amax" + ) + return system_wise_max_force < 5e-1 + + all_completed_states, convergence_tensor = [], None + while True: + state, completed_states = batcher.next_batch(state, convergence_tensor) + + all_completed_states.extend(completed_states) + if state is None: + break + + for _ in range(num_steps_per_batch): + state = ts.bfgs_step(state=state, model=lj_model) + convergence_tensor = convergence_fn(state) + + assert len(all_completed_states) == len(bfgs_states) + + +def test_binning_auto_batcher_with_bfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with BFGS optimizer (matching FIRE test structure).""" + si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_bfgs_state = ts.bfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5 + bfgs_states = [state.clone() for state in bfgs_states] + for state in bfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = BinningAutoBatcher( + model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000 + ) + batcher.load_states(bfgs_states) + + all_finished_states: list[ts.SimState] = [] + total_batches = 0 + for batch, _ in batcher: + total_batches += 1 # noqa: SIM113 + for _ in range(5): + batch = ts.bfgs_step(state=batch, model=lj_model) + all_finished_states.extend(batch.split()) + + assert len(all_finished_states) == len(bfgs_states) + + +def _group_states_by_size( + states: list[ts.SimState], +) -> list[list[tuple[int, ts.SimState]]]: + """Group states by n_atoms, preserving original indices for order restoration. + + Used for L-BFGS which requires same-sized systems in each batch due to + history tensor shapes being dependent on n_atoms. + """ + from itertools import groupby + + indexed_states = list(enumerate(states)) + sorted_states = sorted(indexed_states, key=lambda x: x[1].n_atoms) + groups = [] + for _, group in groupby(sorted_states, key=lambda x: x[1].n_atoms): + groups.append(list(group)) + return groups + + +@pytest.mark.skip( + reason="L-BFGS with InFlightAutoBatcher has a known issue: history tensors " + "become misaligned when systems are dynamically removed on convergence. " + "Use BinningAutoBatcher instead for L-BFGS." +) +@pytest.mark.parametrize( + "num_steps_per_batch", + [ + 5, # At 5 steps, not every state will converge before the next batch. + 10, # At 10 steps, all states will converge before the next batch + ], +) +def test_in_flight_with_lbfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + num_steps_per_batch: int, +) -> None: + """Test InFlightAutoBatcher with L-BFGS optimizer (matching FIRE test structure).""" + si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_lbfgs_state = ts.lbfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5 + lbfgs_states = [state.clone() for state in lbfgs_states] + for state in lbfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = InFlightAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=6000, + ) + batcher.load_states(lbfgs_states) + + def convergence_fn(state: ts.LBFGSState) -> torch.Tensor: + system_wise_max_force = torch.zeros( + state.n_systems, device=state.device, dtype=torch.float64 + ) + max_forces = state.forces.norm(dim=1) + system_wise_max_force = system_wise_max_force.scatter_reduce( + dim=0, index=state.system_idx, src=max_forces, reduce="amax" + ) + return system_wise_max_force < 5e-1 + + all_completed_states, convergence_tensor = [], None + while True: + state, completed_states = batcher.next_batch(state, convergence_tensor) + + all_completed_states.extend(completed_states) + if state is None: + break + + for _ in range(num_steps_per_batch): + state = ts.lbfgs_step(state=state, model=lj_model) + convergence_tensor = convergence_fn(state) + + assert len(all_completed_states) == len(lbfgs_states) + + +def test_binning_auto_batcher_with_lbfgs( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with L-BFGS optimizer (matching FIRE test structure).""" + si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) + fe_lbfgs_state = ts.lbfgs_init( + fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit + ) + + lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5 + lbfgs_states = [state.clone() for state in lbfgs_states] + for state in lbfgs_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + # Group by size and process each group separately + size_groups = _group_states_by_size(lbfgs_states) + all_finished_with_indices: list[tuple[int, ts.SimState]] = [] + total_batches = 0 + + for group in size_groups: + original_indices, group_states = zip(*group, strict=True) + group_states_list = list(group_states) + + batcher = BinningAutoBatcher( + model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000 + ) + batcher.load_states(group_states_list) + + finished_states = [] + for batch, _ in batcher: + total_batches += 1 + for _ in range(5): + batch = ts.lbfgs_step(state=batch, model=lj_model) + finished_states.extend(batch.split()) + + restored = batcher.restore_original_order(finished_states) + for idx, finished_state in zip(original_indices, restored, strict=True): + all_finished_with_indices.append((idx, finished_state)) + + # Sort by original index to restore order + all_finished_with_indices.sort(key=lambda x: x[0]) + all_finished_states = [s for _, s in all_finished_with_indices] + + assert len(all_finished_states) == len(lbfgs_states) + for restored, original in zip(all_finished_states, lbfgs_states, strict=True): + assert torch.all(restored.atomic_numbers == original.atomic_numbers) diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index bb7021c5d..e23a65c94 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -7,19 +7,28 @@ The implementation handles batches of systems with different numbers of atoms by padding vectors to the maximum number of atoms in the batch. The Hessian matrices are similarly padded to shape (n_systems, 3*max_atoms, 3*max_atoms). + +Note: When cell_filter is active, forces are transformed using the deformation gradient +to work in the same scaled coordinate space as ASE's UnitCellFilter/FrechetCellFilter. +The prev_forces and prev_positions are stored in the scaled/fractional space to match +ASE's behavior. """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch +import torch_sim as ts +from torch_sim.optimizers import cell_filters +from torch_sim.optimizers.cell_filters import frechet_cell_filter_init from torch_sim.state import SimState from torch_sim.typing import StateDict if TYPE_CHECKING: from torch_sim.models.interface import ModelInterface - from torch_sim.optimizers import BFGSState + from torch_sim.optimizers import BFGSState, CellBFGSState + from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs def _get_atom_indices_per_system( @@ -76,75 +85,174 @@ def bfgs_init( *, max_step: float = 0.2, alpha: float = 70.0, -) -> "BFGSState": + cell_filter: "CellFilter | CellFilterFuncs | None" = None, + **filter_kwargs: Any, +) -> "BFGSState | CellBFGSState": """Create an initial BFGSState. - Initializes the Hessian as Identity * alpha. + Initializes the Hessian as Identity matrix * alpha. + + Shape notation: + N = total atoms across all systems (n_atoms) + S = number of systems (n_systems) + M = max atoms per system (max_atoms) + D = 3*M (position DOFs) + D_ext = 3*M + 9 (extended DOFs with cell) Args: state: Input state model: Model max_step: Maximum step size (Angstrom) alpha: Initial Hessian stiffness (eV/A^2) + cell_filter: Filter for cell optimization (None for position-only optimization) + **filter_kwargs: Additional arguments passed to cell filter initialization Returns: - BFGSState + BFGSState or CellBFGSState if cell_filter is provided """ - from torch_sim.optimizers import BFGSState + from torch_sim.optimizers import BFGSState, CellBFGSState tensor_args = {"device": model.device, "dtype": model.dtype} if not isinstance(state, SimState): state = SimState(**state) - n_systems = state.n_systems + n_systems = state.n_systems # S - counts = state.n_atoms_per_system - max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 - atom_idx = _get_atom_indices_per_system(state.system_idx, n_systems) + counts = state.n_atoms_per_system # [S] + global_max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 # M + # Per-system max_atoms for padding/unpadding support + max_atoms = counts.clone() # [S] - each system's atom count + atom_idx = _get_atom_indices_per_system(state.system_idx, n_systems) # [N] model_output = model(state) - energy = model_output["energy"] - forces = model_output["forces"] - stress = model_output["stress"] - - # shape: (n_systems, 3*max_atoms, 3*max_atoms) - dim = 3 * max_atoms - hessian = torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha - - alpha_t = torch.full((n_systems,), alpha, **tensor_args) - max_step_t = torch.full((n_systems,), max_step, **tensor_args) - n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) - - return BFGSState( - positions=state.positions.clone(), - masses=state.masses.clone(), - cell=state.cell.clone(), - atomic_numbers=state.atomic_numbers.clone(), - forces=forces, - energy=energy, - stress=stress, - hessian=hessian, - prev_forces=forces.clone(), - prev_positions=state.positions.clone(), - alpha=alpha_t, - max_step=max_step_t, - n_iter=n_iter, - atom_idx_in_system=atom_idx, - max_atoms=max_atoms, - # passed to __post_init__ - system_idx=state.system_idx.clone(), - pbc=state.pbc, - ) - - -def bfgs_step( - state: "BFGSState", + energy = model_output["energy"] # [S] + forces = model_output["forces"] # [N, 3] + stress = model_output.get("stress") # [S, 3, 3] or None + + alpha_t = torch.full((n_systems,), alpha, **tensor_args) # [S] + max_step_t = torch.full((n_systems,), max_step, **tensor_args) # [S] + n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) # [S] + + if cell_filter is not None: + # Extended Hessian: (3*global_max_atoms + 9) x (3*global_max_atoms + 9) + # The extra 9 DOFs are for cell parameters (3x3 matrix flattened) + dim = 3 * global_max_atoms + (3 * 3) # D_ext + hessian = ( + torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha + ) # [S, D_ext, D_ext] + + cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) + + # Note (AG): At initialization, deform_grad is identity, so we have: + # fractional = Cartesian / cell and scaled forces = forces @ I = forces + # For ASE compatibility, we need to store prev_positions as fractional coords + # and prev_forces as scaled forces + + # Get initial deform_grad (identity at start since reference_cell = current_cell) + reference_cell = state.cell.clone() # [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + reference_cell.mT, state.cell.mT + ) # [S, 3, 3] + + # Initial fractional positions = solve(deform_grad, positions) = positions + # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Initial scaled forces = forces @ deform_grad = forces + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] + scaled_forces = torch.bmm( + forces.unsqueeze(1), # [N, 1, 3] + cur_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) + + common_args = { + "positions": state.positions.clone(), # [N, 3] + "masses": state.masses.clone(), # [N] + "cell": state.cell.clone(), # [S, 3, 3] + "atomic_numbers": state.atomic_numbers.clone(), # [N] + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "hessian": hessian, # [S, D_ext, D_ext] + # Note (AG): Store fractional positions and scaled forces + # for ASE compatibility + "prev_forces": scaled_forces, # [N, 3] (scaled) + "prev_positions": frac_positions, # [N, 3] (fractional) + "alpha": alpha_t, # [S] + "max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # scalar M + "system_idx": state.system_idx.clone(), # [N] + "pbc": state.pbc, # [S, 3] + "reference_cell": reference_cell, # [S, 3, 3] + "cell_filter": cell_filter_funcs, + } + + cell_state = CellBFGSState(**common_args) + + # Initialize cell-specific attributes (cell_positions, cell_forces, etc.) + # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] + init_fn(cell_state, model, **filter_kwargs) + + # Store prev_cell_positions and prev_cell_forces for Hessian update + cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] + cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + + return cell_state + + # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms + dim = 3 * global_max_atoms # D + hessian = ( + torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha + ) # [S, D, D] + + common_args = { + "positions": state.positions.clone(), # [N, 3] + "masses": state.masses.clone(), # [N] + "cell": state.cell.clone(), # [S, 3, 3] + "atomic_numbers": state.atomic_numbers.clone(), # [N] + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "hessian": hessian, # [S, D, D] + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "alpha": alpha_t, # [S] + "max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # scalar M + "system_idx": state.system_idx.clone(), # [N] + "pbc": state.pbc, # [S, 3] + } + + return BFGSState(**common_args) + + +def bfgs_step( # noqa: C901, PLR0915 + state: "BFGSState | CellBFGSState", model: "ModelInterface", -) -> "BFGSState": +) -> "BFGSState | CellBFGSState": """Perform one BFGS optimization step. - Updates the Hessian estimate and moves atoms. + Updates the Hessian estimate and moves atoms. If state is a CellBFGSState, + forces are transformed using the deformation gradient to work in the same + scaled coordinate space as ASE's cell filters (matching FIRE's approach). + + For cell optimization, prev_positions are stored as fractional coordinates + and prev_forces as scaled forces, exactly matching ASE's pos0/forces0. + + Shape notation: + N = total atoms across all systems (n_atoms) + S = number of systems (n_systems) + M = max atoms per system (max_atoms) + D = 3*M (position DOFs) + D_ext = 3*M + 9 (extended DOFs with cell) Args: state: Current optimization state @@ -153,130 +261,286 @@ def bfgs_step( Returns: Updated state """ - eps = 1e-7 + from torch_sim.optimizers import CellBFGSState - # Pack flat tensors into dense batched tensors - # shape: (n_systems, max_atoms * 3) - pos_new = _pad_to_dense( - state.positions, - state.system_idx, - state.atom_idx_in_system, - state.n_systems, - state.max_atoms, - ).reshape(state.n_systems, -1) - - forces_new = _pad_to_dense( - state.forces, - state.system_idx, - state.atom_idx_in_system, - state.n_systems, - state.max_atoms, - ).reshape(state.n_systems, -1) - - pos_old = _pad_to_dense( - state.prev_positions, - state.system_idx, - state.atom_idx_in_system, - state.n_systems, - state.max_atoms, - ).reshape(state.n_systems, -1) - - forces_old = _pad_to_dense( - state.prev_forces, - state.system_idx, - state.atom_idx_in_system, - state.n_systems, - state.max_atoms, - ).reshape(state.n_systems, -1) + # Note (AG): eps kept same as ASE's BFGS. + eps = 1e-7 + is_cell_state = isinstance(state, CellBFGSState) + + # Derive global_max_atoms from hessian shape + hessian_dim = state.hessian.shape[1] + global_max_atoms = (hessian_dim - 9) // 3 if is_cell_state else hessian_dim // 3 + + if is_cell_state: + # Get current deformation gradient + # reference_cell.mT: [S, 3, 3], row_vector_cell: [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + + # Transform forces to scaled coordinates + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] + forces_scaled = torch.bmm( + state.forces.unsqueeze(1), # [N, 1, 3] + cur_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) # [N, 3] + + # Current fractional positions + # positions: [N, 3] -> frac_positions: [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Pack into dense tensors [N, 3] -> [S, M, 3] -> [S, D] + # For cell state, prev_positions is already fractional (stored that way) + # prev_forces is already scaled + # Note (AG): Optimization potential here. + forces_new = _pad_to_dense( + forces_scaled, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + forces_old = _pad_to_dense( + state.prev_forces, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_new = _pad_to_dense( + frac_positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_old = _pad_to_dense( + state.prev_positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + # Extend with cell DOFs: [S, 3, 3] -> [S, 9] + cell_pos_new = state.cell_positions.reshape(state.n_systems, 9) # [S, 9] + cell_forces_new = state.cell_forces.reshape(state.n_systems, 9) # [S, 9] + cell_pos_old = state.prev_cell_positions.reshape(state.n_systems, 9) # [S, 9] + cell_forces_old = state.prev_cell_forces.reshape(state.n_systems, 9) # [S, 9] + + # Concatenate: extended = [positions, cell_positions] + # [S, D] + [S, 9] -> [S, D_ext] + pos_new = torch.cat([pos_new, cell_pos_new], dim=1) # [S, D_ext] + forces_new = torch.cat([forces_new, cell_forces_new], dim=1) # [S, D_ext] + pos_old = torch.cat([pos_old, cell_pos_old], dim=1) # [S, D_ext] + forces_old = torch.cat([forces_old, cell_forces_old], dim=1) # [S, D_ext] + else: + forces_scaled = state.forces # [N, 3] + + # Pack into dense tensors [N, 3] -> [S, M, 3] -> [S, D] + forces_new = _pad_to_dense( + state.forces, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + forces_old = _pad_to_dense( + state.prev_forces, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_new = _pad_to_dense( + state.positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] + + pos_old = _pad_to_dense( + state.prev_positions, # [N, 3] + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + global_max_atoms, + ).reshape(state.n_systems, -1) # [S, D] # Calculate displacements and force changes - # dpos: (n_systems, max_atoms * 3) - dpos = pos_new - pos_old - dforces = -(forces_new - forces_old) + # dim = D or D_ext depending on cell_state + dpos = pos_new - pos_old # [S, dim] + dforces = forces_new - forces_old # [S, dim] # Identify systems with significant movement - max_disp = torch.max(torch.abs(dpos), dim=1).values - update_mask = max_disp >= eps + max_disp = torch.max(torch.abs(dpos), dim=1).values # [S] + update_mask = max_disp >= eps # [S] bool - # Update Hessian for active systems + # Update Hessian for active systems (BFGS update formula) if update_mask.any(): idx = update_mask - H = state.hessian[idx] - - # shape: (n_active, D, 1) - dp = dpos[idx].unsqueeze(2) - df = dforces[idx].unsqueeze(2) # noqa: PD901 + H = state.hessian[idx] # [S_active, dim, dim] - # shape: (n_active, 1) - a = torch.bmm(dp.transpose(1, 2), df).squeeze(2) + dp = dpos[idx].unsqueeze(2) # [S_active, dim, 1] + df = dforces[idx].unsqueeze(2) # [S_active, dim, 1] # noqa: PD901 - # shape: (n_active, D, 1) - dg = torch.bmm(H, dp) + # a = dp^T @ df: [S_active, 1, dim] @ [S_active, dim, 1] -> [S_active, 1, 1] + a = torch.bmm(dp.transpose(1, 2), df).squeeze(2) # [S_active, 1] + # dg = H @ dp: [S_active, dim, dim] @ [S_active, dim, 1] -> [S_active, dim, 1] + dg = torch.bmm(H, dp) # [S_active, dim, 1] + # b = dp^T @ dg: [S_active, 1, dim] @ [S_active, dim, 1] -> [S_active, 1, 1] + b = torch.bmm(dp.transpose(1, 2), dg).squeeze(2) # [S_active, 1] - # shape: (n_active, 1) - b = torch.bmm(dp.transpose(1, 2), dg).squeeze(2) - - # Rank-2 update - # shape: (n_active, D, D) + # term1 = df @ df^T / a: [S_active, dim, dim] term1 = torch.bmm(df, df.transpose(1, 2)) / (a.unsqueeze(2) + 1e-30) + # term2 = dg @ dg^T / b: [S_active, dim, dim] term2 = torch.bmm(dg, dg.transpose(1, 2)) / (b.unsqueeze(2) + 1e-30) - state.hessian[idx] = H - term1 - term2 + state.hessian[idx] = H - term1 - term2 # [S_active, dim, dim] # Calculate step direction using eigendecomposition - # gradient: (n_systems, D, 1) - # Step p = H^-1 * F - direction = forces_new.unsqueeze(2) - - # omega: (n_systems, D), V: (n_systems, D, D) - omega, V = torch.linalg.eigh(state.hessian) - - # shape: (n_systems, 1, D) - abs_omega = torch.abs(omega).unsqueeze(1) - abs_omega = torch.where(abs_omega < 1e-30, torch.ones_like(abs_omega), abs_omega) - - # Project direction onto eigenvectors and scale - # shape: (n_systems, D, 1) - vt_g = torch.bmm(V.transpose(1, 2), direction) - scaled = vt_g / abs_omega.transpose(1, 2) - - # Transform back to original basis - # shape: (n_systems, D) - step_dense = torch.bmm(V, scaled).squeeze(2) + # step = V @ (|omega|^-1) @ V^T @ forces (pseudo-inverse via eigendecomposition) + # Note (AG): We use eigendecomposition rather than directly inverting H so we can + # take the absolute value of eigenvalues (|omega|). This ensures the step is always + # in a descent direction even if the Hessian approximation has negative eigenvalues. + + # Size-binned eigendecomposition: group systems by actual Hessian size + hessian_dim = state.hessian.shape[1] + step_dense = torch.zeros( + state.n_systems, hessian_dim, device=state.device, dtype=state.dtype + ) # [S, dim] + + # Get unique sizes and process each group with batched eigendecomp + # TODO(AG): If we sort and get the sizes before hand we can reduce the + # python loop overhead. + unique_sizes = state.max_atoms.unique() + + for size in unique_sizes: + actual_dim = int(3 * size.item()) + (9 if is_cell_state else 0) + mask = state.max_atoms == size # [S] bool - systems with this size + + # Extract actual-sized Hessians and forces for this group + H_group = state.hessian[mask, :actual_dim, :actual_dim] # [G, d, d] + f_group = forces_new[mask, :actual_dim] # [G, d] + + # Batched eigendecomposition on actual size (no padding overhead) + omega, V = torch.linalg.eigh(H_group) # omega: [G, d], V: [G, d, d] + abs_omega = torch.abs(omega).clamp(min=1e-30) # [G, d] + + # Compute step: V @ (V^T @ f / |omega|) + vt_f = torch.bmm(V.transpose(1, 2), f_group.unsqueeze(2)) # [G, d, 1] + step_group = torch.bmm(V, vt_f / abs_omega.unsqueeze(2)).squeeze(2) # [G, d] + + # Place results back into step_dense (padded to hessian_dim) + indices = mask.nonzero(as_tuple=True)[0] + step_dense[indices, :actual_dim] = step_group + + # Split step into position and cell components + atom_dim = 3 * global_max_atoms # D + if is_cell_state: + step_pos = step_dense[:, :atom_dim] # [S, D] + step_cell = step_dense[:, atom_dim:] # [S, 9] + else: + step_pos = step_dense # [S, D] # Scale step if it exceeds max_step - # step_atoms: (n_systems, max_atoms, 3) - step_atoms = step_dense.view(state.n_systems, state.max_atoms, 3) - # atom_norms: (n_systems, max_atoms) - atom_norms = torch.norm(step_atoms, dim=2) - - # max_disp_per_sys: (n_systems,) - max_disp_per_sys = torch.max(atom_norms, dim=1).values - - scale = torch.ones_like(max_disp_per_sys) - needs_scale = max_disp_per_sys > state.max_step + step_atoms = step_pos.view(state.n_systems, global_max_atoms, 3) # [S, M, 3] + atom_norms = torch.norm(step_atoms, dim=2) # [S, M] + + if is_cell_state: + step_cell_reshaped = step_cell.view(state.n_systems, 3, 3) # [S, 3, 3] + cell_norms = torch.norm(step_cell_reshaped, dim=2) # [S, 3] + all_norms = torch.cat([atom_norms, cell_norms], dim=1) # [S, M+3] + max_disp_per_sys = torch.max(all_norms, dim=1).values # [S] + else: + max_disp_per_sys = torch.max(atom_norms, dim=1).values # [S] + + scale = torch.ones_like(max_disp_per_sys) # [S] + needs_scale = max_disp_per_sys > state.max_step # [S] bool scale[needs_scale] = state.max_step[needs_scale] / ( max_disp_per_sys[needs_scale] + 1e-30 ) - # shape: (n_systems, D) - step_dense = step_dense * scale.unsqueeze(1) + step_pos = step_pos * scale.unsqueeze(1) # [S, D] + if is_cell_state: + step_cell = step_cell * scale.unsqueeze(1) # [S, 9] - # Unpack dense step back to flat valid atoms - flat_step = step_dense.view(state.n_systems, state.max_atoms, 3)[ + # Unpack dense step to flat: [S, M, 3] -> [N, 3] + flat_step = step_pos.view(state.n_systems, global_max_atoms, 3)[ state.system_idx, state.atom_idx_in_system - ] - - new_positions = state.positions + flat_step + ] # [N, 3] + + # Save previous state for next Hessian update + # For cell state: store fractional positions and scaled forces (ASE convention) + if is_cell_state: + state.prev_positions = frac_positions.clone() # [N, 3] (fractional) + state.prev_forces = forces_scaled.clone() # [N, 3] (scaled) + state.prev_cell_positions = state.cell_positions.clone() # [S, 3, 3] + state.prev_cell_forces = state.cell_forces.clone() # [S, 3, 3] + + # Apply cell step: [S, 9] -> [S, 3, 3] + dr_cell = step_cell.view(state.n_systems, 3, 3) # [S, 3, 3] + cell_positions_new = state.cell_positions + dr_cell # [S, 3, 3] + state.cell_positions = cell_positions_new # [S, 3, 3] + + # Determine if Frechet filter + init_fn, _step_fn = state.cell_filter + is_frechet = init_fn is frechet_cell_filter_init + + if is_frechet: + # Frechet: deform_grad = exp(cell_positions / cell_factor) + cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1) + deform_grad_log_new = cell_positions_new / cell_factor_reshaped # [S, 3, 3] + deform_grad_new = torch.matrix_exp(deform_grad_log_new) # [S, 3, 3] + else: + # UnitCell: deform_grad = cell_positions / cell_factor + cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) + deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3] + + # Update cell: new_cell = reference_cell @ deform_grad^T + # reference_cell.mT: [S, 3, 3], deform_grad_new: [S, 3, 3] + state.row_vector_cell = torch.bmm( + state.reference_cell.mT, deform_grad_new.transpose(-2, -1) + ) # [S, 3, 3] + + # Apply position step in fractional space, then convert to Cartesian + new_frac = frac_positions + flat_step # [N, 3] + + new_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + # new_positions = new_frac @ deform_grad^T + new_positions = torch.bmm( + new_frac.unsqueeze(1), # [N, 1, 3] + new_deform_grad[state.system_idx].transpose(-2, -1), # [N, 3, 3] + ).squeeze(1) # [N, 3] + state.positions = new_positions # [N, 3] + else: + state.prev_positions = state.positions.clone() # [N, 3] + state.prev_forces = state.forces.clone() # [N, 3] + state.positions = state.positions + flat_step # [N, 3] + + # Evaluate new forces and energy + model_output = model(state) + state.forces = model_output["forces"] # [N, 3] + state.energy = model_output["energy"] # [S] + if "stress" in model_output: + state.stress = model_output["stress"] # [S, 3, 3] - state.prev_positions = state.positions.clone() - state.prev_forces = state.forces.clone() - state.positions = new_positions + # Update cell forces for next step + # Update cell forces for cell state: [S, 3, 3] + if is_cell_state: + cell_filters.compute_cell_forces(model_output, state) - model_output = model(state) - state.forces = model_output["forces"] - state.energy = model_output["energy"] - state.stress = model_output["stress"] state.n_iter += 1 return state diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 28b265b5a..872b3b587 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -1,6 +1,7 @@ """Optimizer state classes.""" from dataclasses import dataclass +from typing import ClassVar import torch @@ -55,17 +56,21 @@ class BFGSState(OptimState): """State for batched BFGS optimization. Stores the state needed to run a batched BFGS optimizer that maintains - an approximate Hessian or inverse Hessian. + an approximate Hessian matrix. Attributes: - hessian: Hessian matrix [n_systems, 3*max_atoms, 3*max_atoms] - prev_forces: Previous-step forces [n_atoms, 3] - prev_positions: Previous-step positions [n_atoms, 3] - alpha: Initial Hessian scale [n_systems] - max_step: Maximum step size [n_systems] + hessian: Hessian matrix [n_systems, dim, dim] where dim = 3*max_atoms + for position-only or 3*max_atoms + 9 with cell filter. + May be padded when systems have different sizes. + prev_forces: Previous-step forces [n_atoms, 3]. For cell filter, + these are scaled forces (forces @ deform_grad) for ASE compatibility. + prev_positions: Previous-step positions [n_atoms, 3]. For cell filter, + these are fractional coordinates for ASE compatibility. + alpha: Initial Hessian scale (stiffness) [n_systems] + max_step: Maximum step size per atom [n_systems] n_iter: Per-system iteration counter [n_systems] (int32) atom_idx_in_system: Index of each atom within its system [n_atoms] - max_atoms: Maximum number of atoms in any system (int) + max_atoms: Atoms per system [n_systems] - used for size-binned eigendecomp """ hessian: torch.Tensor @@ -75,7 +80,7 @@ class BFGSState(OptimState): max_step: torch.Tensor n_iter: torch.Tensor atom_idx_in_system: torch.Tensor - max_atoms: int + max_atoms: torch.Tensor # Changed from int to Tensor for padding support _atom_attributes = OptimState._atom_attributes | { # noqa: SLF001 "prev_forces", @@ -89,6 +94,8 @@ class BFGSState(OptimState): "n_iter", "max_atoms", } + # Attributes that need padding when concatenating different-sized systems + _padded_system_attributes: ClassVar[set[str]] = {"hessian"} @dataclass(kw_only=True) @@ -101,11 +108,14 @@ class LBFGSState(OptimState): systems via `system_idx`. Attributes: - prev_forces: Previous-step forces [n_atoms, 3] - prev_positions: Previous-step positions [n_atoms, 3] - s_history: Displacement history [h, n_atoms, 3] - y_history: Gradient-diff history [h, n_atoms, 3] + prev_forces: Previous-step forces [n_atoms, 3]. For cell filter, + these are scaled forces (forces @ deform_grad) for ASE compatibility. + prev_positions: Previous-step positions [n_atoms, 3]. For cell filter, + these are fractional coordinates for ASE compatibility. + s_history: Displacement history [h, n_atoms, 3] (global, not per-system) + y_history: Gradient-diff history [h, n_atoms, 3] (global, not per-system) step_size: Per-system fixed step size [n_systems] + alpha: Initial inverse Hessian scale (stiffness) [n_systems] n_iter: Per-system iteration counter [n_systems] (int32) """ diff --git a/torch_sim/state.py b/torch_sim/state.py index 318454c79..f7c7a4c84 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -848,6 +848,12 @@ def _split_state[T: SimState](state: T) -> list[T]: zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64) cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0))) for sys_idx in range(n_systems): + # Build per-system attributes (padded attributes stay padded for consistency) + per_system_dict = { + attr_name: split_per_system[attr_name][sys_idx] + for attr_name in split_per_system + } + system_attrs = { # Create a system tensor with all zeros for this system "system_idx": torch.zeros( @@ -858,11 +864,8 @@ def _split_state[T: SimState](state: T) -> list[T]: attr_name: split_per_atom[attr_name][sys_idx] for attr_name in split_per_atom }, - # Add the split per-system attributes - **{ - attr_name: split_per_system[attr_name][sys_idx] - for attr_name in split_per_system - }, + # Add the split per-system attributes (with unpadding applied) + **per_system_dict, # Add the global attributes **global_attrs, } @@ -1039,10 +1042,26 @@ def concatenate_states[T: SimState]( # noqa: C901 # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) + # Get padded attributes if defined on the state class + padded_attrs = getattr(first_state, "_padded_system_attributes", set()) + for prop, tensors in per_system_tensors.items(): # if tensors: if isinstance(tensors[0], torch.Tensor): - concatenated[prop] = torch.cat(tensors, dim=0) + if prop in padded_attrs: + # Pad tensors to max size before concatenating + # Assumes padding is needed on last two dimensions (e.g., hessian) + max_size = max(t.shape[-1] for t in tensors) + padded_tensors = [] + for t in tensors: + if t.shape[-1] < max_size: + pad_size = max_size - t.shape[-1] + # Pad last two dimensions (for 3D tensors like hessian) + t = torch.nn.functional.pad(t, (0, pad_size, 0, pad_size)) + padded_tensors.append(t) + concatenated[prop] = torch.cat(padded_tensors, dim=0) + else: + concatenated[prop] = torch.cat(tensors, dim=0) else: # Non-tensor attributes, take first one (they should all be identical) concatenated[prop] = tensors[0] From 8983439cef38bbb8fd8a07b09a73a66df7eb92cc Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 03:03:07 -0800 Subject: [PATCH 10/17] charge and spin --- torch_sim/__init__.py | 2 ++ torch_sim/optimizers/bfgs.py | 4 +++ torch_sim/optimizers/state.py | 17 +++++++---- torch_sim/state.py | 53 +++++++++++++++++++++++++++-------- 4 files changed, 59 insertions(+), 17 deletions(-) diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index 36bde051b..b56973170 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -69,8 +69,10 @@ ) from torch_sim.optimizers.cell_filters import ( CELL_FILTER_REGISTRY, + CellBFGSState, CellFilter, CellFireState, + CellLBFGSState, CellOptimState, get_cell_filter, ) diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index e23a65c94..3bdfbb0be 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -191,6 +191,8 @@ def bfgs_init( "pbc": state.pbc, # [S, 3] "reference_cell": reference_cell, # [S, 3, 3] "cell_filter": cell_filter_funcs, + "charge": state.charge, # preserve charge + "spin": state.spin, # preserve spin } cell_state = CellBFGSState(**common_args) @@ -229,6 +231,8 @@ def bfgs_init( "max_atoms": max_atoms, # scalar M "system_idx": state.system_idx.clone(), # [N] "pbc": state.pbc, # [S, 3] + "charge": state.charge, # preserve charge + "spin": state.spin, # preserve spin } return BFGSState(**common_args) diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 872b3b587..fb09795d0 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -112,11 +112,16 @@ class LBFGSState(OptimState): these are scaled forces (forces @ deform_grad) for ASE compatibility. prev_positions: Previous-step positions [n_atoms, 3]. For cell filter, these are fractional coordinates for ASE compatibility. - s_history: Displacement history [h, n_atoms, 3] (global, not per-system) - y_history: Gradient-diff history [h, n_atoms, 3] (global, not per-system) + s_history: Displacement history [n_systems, h, max_atoms, 3] per-system. + For cell filter: [n_systems, h, max_atoms + 3, 3] to include cell DOFs. + May be padded when systems have different sizes. + y_history: Gradient-diff history [n_systems, h, max_atoms, 3] per-system. + For cell filter: [n_systems, h, max_atoms + 3, 3] to include cell DOFs. + May be padded when systems have different sizes. step_size: Per-system fixed step size [n_systems] alpha: Initial inverse Hessian scale (stiffness) [n_systems] n_iter: Per-system iteration counter [n_systems] (int32) + max_atoms: Atoms per system [n_systems] - used for size-binned operations """ prev_forces: torch.Tensor @@ -126,6 +131,7 @@ class LBFGSState(OptimState): step_size: torch.Tensor alpha: torch.Tensor n_iter: torch.Tensor + max_atoms: torch.Tensor # [S] atoms per system for padding support _atom_attributes = OptimState._atom_attributes | { # noqa: SLF001 "prev_forces", @@ -135,13 +141,12 @@ class LBFGSState(OptimState): "step_size", "alpha", "n_iter", - } - # Note (AG): s_history and y_history are global attributes because they are not - # per-system indexable, so they must be copied as-is on slice. - _global_attributes = OptimState._global_attributes | { # noqa: SLF001 + "max_atoms", "s_history", "y_history", } + # Attributes that need padding when concatenating different-sized systems + _padded_system_attributes: ClassVar[set[str]] = {"s_history", "y_history"} # there's no GradientDescentState, it's the same as OptimState diff --git a/torch_sim/state.py b/torch_sim/state.py index f7c7a4c84..f465938c4 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -967,7 +967,7 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor return type(state)(**filtered_attrs) # type: ignore[invalid-return-type] -def concatenate_states[T: SimState]( # noqa: C901 +def concatenate_states[T: SimState]( # noqa: C901, PLR0915 states: Sequence[T], device: torch.device | None = None ) -> T: """Concatenate a list of SimStates into a single SimState. @@ -1048,18 +1048,49 @@ def concatenate_states[T: SimState]( # noqa: C901 for prop, tensors in per_system_tensors.items(): # if tensors: if isinstance(tensors[0], torch.Tensor): + # TODO(AG): Is there a clean way to handle this? if prop in padded_attrs: # Pad tensors to max size before concatenating - # Assumes padding is needed on last two dimensions (e.g., hessian) - max_size = max(t.shape[-1] for t in tensors) - padded_tensors = [] - for t in tensors: - if t.shape[-1] < max_size: - pad_size = max_size - t.shape[-1] - # Pad last two dimensions (for 3D tensors like hessian) - t = torch.nn.functional.pad(t, (0, pad_size, 0, pad_size)) - padded_tensors.append(t) - concatenated[prop] = torch.cat(padded_tensors, dim=0) + # Detect tensor shape to determine padding strategy + first_tensor = tensors[0] + ndim = first_tensor.ndim + + if ndim == 3: + # Shape [S, D, D] required for BFGS hessian + # Pad last two dimensions + max_size = max(t.shape[-1] for t in tensors) + padded_tensors = [] + for t in tensors: + if t.shape[-1] < max_size: + pad_size = max_size - t.shape[-1] + t = torch.nn.functional.pad(t, (0, pad_size, 0, pad_size)) + padded_tensors.append(t) + concatenated[prop] = torch.cat(padded_tensors, dim=0) + elif ndim == 4: + # Shape [S, H, M, 3] required for L-BFGS history + # Pad dimension 2 (M) to max, and dimension 1 (H) to max + max_m = max(t.shape[2] for t in tensors) # max atoms dim + max_h = max(t.shape[1] for t in tensors) # max history dim + padded_tensors = [] + for t in tensors: + s_dim, h_dim, m_dim, last_dim = t.shape + if h_dim == 0: + # Special case: empty history, just create new shape + t = torch.zeros( + (s_dim, max_h, max_m, last_dim), + device=t.device, + dtype=t.dtype, + ) + elif m_dim < max_m or h_dim < max_h: + pad_m = max_m - m_dim + pad_h = max_h - h_dim + # For [S, H, M, 3]: pad M (dim 2) and H (dim 1) + t = torch.nn.functional.pad(t, (0, 0, 0, pad_m, 0, pad_h)) + padded_tensors.append(t) + concatenated[prop] = torch.cat(padded_tensors, dim=0) + else: + # Unknown shape, just concatenate without padding + concatenated[prop] = torch.cat(tensors, dim=0) else: concatenated[prop] = torch.cat(tensors, dim=0) else: # Non-tensor attributes, take first one (they should all be identical) From f51a0db5e7b4e660bb14b152983f80fafbd676b2 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 03:03:31 -0800 Subject: [PATCH 11/17] Update tests --- tests/test_autobatching.py | 5 - tests/test_optimizer_states.py | 52 ++++++- tests/test_optimizers.py | 276 ++++++++++++++++++++++++++++++++- 3 files changed, 325 insertions(+), 8 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 1baecc30d..ae5f4f8ca 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -713,11 +713,6 @@ def _group_states_by_size( return groups -@pytest.mark.skip( - reason="L-BFGS with InFlightAutoBatcher has a known issue: history tensors " - "become misaligned when systems are dynamically removed on convergence. " - "Use BinningAutoBatcher instead for L-BFGS." -) @pytest.mark.parametrize( "num_steps_per_batch", [ diff --git a/tests/test_optimizer_states.py b/tests/test_optimizer_states.py index eee408756..2052974e4 100644 --- a/tests/test_optimizer_states.py +++ b/tests/test_optimizer_states.py @@ -3,7 +3,7 @@ import pytest import torch -from torch_sim.optimizers.state import FireState, OptimState +from torch_sim.optimizers.state import BFGSState, FireState, LBFGSState, OptimState from torch_sim.state import SimState @@ -57,3 +57,53 @@ def test_fire_state_custom_values(sim_state: SimState, optim_data: dict) -> None assert torch.equal(state.dt, fire_data["dt"]) assert torch.equal(state.alpha, fire_data["alpha"]) assert torch.equal(state.n_pos, fire_data["n_pos"]) + + +def test_bfgs_state_custom_values(sim_state: SimState, optim_data: dict) -> None: + """Test BFGSState with custom values.""" + bfgs_data = { + "hessian": torch.eye(6, dtype=torch.float64).unsqueeze(0), # [1, 6, 6] + "prev_forces": optim_data["forces"].clone(), + "prev_positions": sim_state.positions.clone(), + "alpha": torch.tensor([70.0], dtype=torch.float64), + "max_step": torch.tensor([0.2], dtype=torch.float64), + "n_iter": torch.tensor([0], dtype=torch.int32), + "atom_idx_in_system": torch.arange(2, dtype=torch.int64), + "max_atoms": torch.tensor([2], dtype=torch.int64), + } + + state = BFGSState(**sim_state.attributes, **optim_data, **bfgs_data) + + assert torch.equal(state.hessian, bfgs_data["hessian"]) + assert torch.equal(state.prev_forces, bfgs_data["prev_forces"]) + assert torch.equal(state.prev_positions, bfgs_data["prev_positions"]) + assert torch.equal(state.alpha, bfgs_data["alpha"]) + assert torch.equal(state.max_step, bfgs_data["max_step"]) + assert torch.equal(state.n_iter, bfgs_data["n_iter"]) + assert torch.equal(state.atom_idx_in_system, bfgs_data["atom_idx_in_system"]) + assert torch.equal(state.max_atoms, bfgs_data["max_atoms"]) + + +def test_lbfgs_state_custom_values(sim_state: SimState, optim_data: dict) -> None: + """Test LBFGSState with custom values.""" + lbfgs_data = { + "prev_forces": optim_data["forces"].clone(), + "prev_positions": sim_state.positions.clone(), + "s_history": torch.zeros((1, 0, 2, 3), dtype=torch.float64), # [S, H, M, 3] + "y_history": torch.zeros((1, 0, 2, 3), dtype=torch.float64), + "step_size": torch.tensor([1.0], dtype=torch.float64), + "alpha": torch.tensor([70.0], dtype=torch.float64), + "n_iter": torch.tensor([0], dtype=torch.int32), + "max_atoms": torch.tensor([2], dtype=torch.int64), + } + + state = LBFGSState(**sim_state.attributes, **optim_data, **lbfgs_data) + + assert torch.equal(state.prev_forces, lbfgs_data["prev_forces"]) + assert torch.equal(state.prev_positions, lbfgs_data["prev_positions"]) + assert torch.equal(state.s_history, lbfgs_data["s_history"]) + assert torch.equal(state.y_history, lbfgs_data["y_history"]) + assert torch.equal(state.step_size, lbfgs_data["step_size"]) + assert torch.equal(state.alpha, lbfgs_data["alpha"]) + assert torch.equal(state.n_iter, lbfgs_data["n_iter"]) + assert torch.equal(state.max_atoms, lbfgs_data["max_atoms"]) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 0acb9835a..e8f0e828d 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -9,7 +9,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.optimizers import FireFlavor, FireState, OptimState +from torch_sim.optimizers import BFGSState, FireFlavor, FireState, LBFGSState, OptimState from torch_sim.state import SimState @@ -161,11 +161,269 @@ def test_fire_optimization( ) +def test_bfgs_optimization( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test that the BFGS optimizer actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + + # Initialize BFGS optimizer + state = ts.bfgs_init(current_sim_state, lj_model) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.bfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, f"BFGS optimization did not converge in {max_steps=}" + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"BFGS optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert max_force < 0.3, f"Forces should be small after optimization, got {max_force=}" + + assert not torch.allclose(state.positions, initial_state_positions), ( + "BFGS positions should have changed after optimization." + ) + + +def test_lbfgs_optimization( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test that the L-BFGS optimizer actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + + # Initialize L-BFGS optimizer + state = ts.lbfgs_init(current_sim_state, lj_model) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.lbfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, ( + f"L-BFGS optimization did not converge in {max_steps=}" + ) + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"L-BFGS optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert max_force < 0.3, f"Forces should be small after optimization, got {max_force=}" + + assert not torch.allclose(state.positions, initial_state_positions), ( + "L-BFGS positions should have changed after optimization." + ) + + +@pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) +def test_bfgs_cell_optimization( + ar_supercell_sim_state: SimState, + lj_model: ModelInterface, + cell_filter: ts.CellFilter, +) -> None: + """Test that BFGS with cell filter actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + current_cell = ( + ar_supercell_sim_state.cell.clone() + + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=current_cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + initial_state_cell = current_sim_state.cell.clone() + + # Initialize BFGS optimizer with cell filter + state = ts.bfgs_init( + state=current_sim_state, + model=lj_model, + cell_filter=cell_filter, + ) + + # Run optimization + energies = [1000.0, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.bfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, ( + f"BFGS {cell_filter.name} optimization did not converge in {max_steps=}" + ) + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"BFGS {cell_filter.name} optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + pressure = torch.trace(state.stress.squeeze(0)) / 3.0 + + assert torch.abs(pressure) < 0.05, ( + f"Pressure should be small after {cell_filter.name} optimization, got {pressure=}" + ) + assert max_force < 0.3, ( + f"Forces should be small after {cell_filter.name} optimization, got {max_force=}" + ) + + assert not torch.allclose(state.positions, initial_state_positions, atol=1e-5), ( + f"BFGS {cell_filter.name} positions should have changed after optimization." + ) + assert not torch.allclose(state.cell, initial_state_cell, atol=1e-5), ( + f"BFGS {cell_filter.name} cell should have changed after optimization." + ) + + +@pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) +def test_lbfgs_cell_optimization( + ar_supercell_sim_state: SimState, + lj_model: ModelInterface, + cell_filter: ts.CellFilter, +) -> None: + """Test that L-BFGS with cell filter actually minimizes energy.""" + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + current_cell = ( + ar_supercell_sim_state.cell.clone() + + torch.randn_like(ar_supercell_sim_state.cell) * 0.01 + ) + + current_sim_state = SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=current_cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + + initial_state_positions = current_sim_state.positions.clone() + initial_state_cell = current_sim_state.cell.clone() + + # Initialize L-BFGS optimizer with cell filter + state = ts.lbfgs_init( + state=current_sim_state, + model=lj_model, + cell_filter=cell_filter, + ) + + # Run optimization + energies = [1000.0, state.energy.item()] + max_steps = 1000 + steps_taken = 0 + + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.lbfgs_step(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + assert steps_taken < max_steps, ( + f"L-BFGS {cell_filter.name} optimization did not converge in {max_steps=}" + ) + + energies = energies[1:] + + # Check that energy decreased + assert energies[-1] < energies[0], ( + f"L-BFGS {cell_filter.name} optimization should reduce energy " + f"(initial: {energies[0]}, final: {energies[-1]})" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + pressure = torch.trace(state.stress.squeeze(0)) / 3.0 + + assert torch.abs(pressure) < 0.05, ( + f"Pressure should be small after {cell_filter.name} optimization, got {pressure=}" + ) + assert max_force < 0.3, ( + f"Forces should be small after {cell_filter.name} optimization, got {max_force=}" + ) + + assert not torch.allclose(state.positions, initial_state_positions, atol=1e-5), ( + f"L-BFGS {cell_filter.name} positions should have changed after optimization." + ) + assert not torch.allclose(state.cell, initial_state_cell, atol=1e-5), ( + f"L-BFGS {cell_filter.name} cell should have changed after optimization." + ) + + @pytest.mark.parametrize( ("optimizer_fn", "expected_state_type"), [ (ts.Optimizer.fire, FireState), (ts.Optimizer.gradient_descent, OptimState), + (ts.Optimizer.bfgs, BFGSState), + (ts.Optimizer.lbfgs, LBFGSState), ], ) def test_simple_optimizer_init_with_dict( @@ -410,6 +668,10 @@ def test_unit_cell_fire_optimization( 50.0, ), (ts.Optimizer.fire, ts.CellFilter.frechet, ts.CellFireState, 75.0), + (ts.Optimizer.bfgs, ts.CellFilter.unit, ts.CellBFGSState, 100), + (ts.Optimizer.bfgs, ts.CellFilter.frechet, ts.CellBFGSState, 75.0), + (ts.Optimizer.lbfgs, ts.CellFilter.unit, ts.CellLBFGSState, 100), + (ts.Optimizer.lbfgs, ts.CellFilter.frechet, ts.CellLBFGSState, 75.0), ], ) def test_cell_optimizer_init_with_dict_and_cell_factor( @@ -462,6 +724,10 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( ts.CellFilter.frechet, ts.CellOptimState, ), + (ts.Optimizer.bfgs, ts.CellFilter.unit, ts.CellBFGSState), + (ts.Optimizer.bfgs, ts.CellFilter.frechet, ts.CellBFGSState), + (ts.Optimizer.lbfgs, ts.CellFilter.unit, ts.CellLBFGSState), + (ts.Optimizer.lbfgs, ts.CellFilter.frechet, ts.CellLBFGSState), ], ) def test_cell_optimizer_init_cell_factor_none( @@ -918,6 +1184,10 @@ def energy_converged(current_energy: torch.Tensor, prev_energy: torch.Tensor) -> (ts.Optimizer.gradient_descent, None), (ts.Optimizer.fire, ts.CellFilter.unit), (ts.Optimizer.gradient_descent, ts.CellFilter.frechet), + (ts.Optimizer.bfgs, None), + (ts.Optimizer.lbfgs, None), + (ts.Optimizer.bfgs, ts.CellFilter.unit), + (ts.Optimizer.lbfgs, ts.CellFilter.frechet), ], ) def test_optimizer_preserves_charge_spin( @@ -956,8 +1226,10 @@ def test_optimizer_preserves_charge_spin( for _ in range(3): if optimizer_fn == ts.Optimizer.fire: opt_state = step_fn(state=opt_state, model=lj_model, dt_max=0.3) - else: + elif optimizer_fn == ts.Optimizer.gradient_descent: opt_state = step_fn(state=opt_state, model=lj_model, pos_lr=0.01, cell_lr=0.1) + else: + opt_state = step_fn(state=opt_state, model=lj_model) assert torch.allclose(opt_state.charge, original_charge) assert torch.allclose(opt_state.spin, original_spin) From 2da4c4481cdb92991c5ca385e87f9fb7643fa4e8 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 03:11:48 -0800 Subject: [PATCH 12/17] BFGS with constraints --- tests/test_constraints.py | 140 +++++++++++++++++++++++++++++++++++ torch_sim/optimizers/bfgs.py | 8 +- 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 3bbfb44ed..a98833e48 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -392,6 +392,111 @@ def test_fix_com_fire_optimization( assert torch.allclose(final_com, initial_com, atol=1e-4) +@pytest.mark.parametrize("optimizer", ["bfgs", "lbfgs"]) +def test_fix_atoms_bfgs_lbfgs_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + optimizer: str, +) -> None: + """Test FixAtoms constraint in BFGS/LBFGS optimization.""" + # Create a fresh copy with random displacement + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + indices = torch.tensor([0, 2], dtype=torch.long) + current_sim_state.constraints = [FixAtoms(atom_idx=indices)] + + # Initialize optimizer + if optimizer == "bfgs": + state = ts.bfgs_init(current_sim_state, lj_model) + step_fn = ts.bfgs_step + else: + state = ts.lbfgs_init(current_sim_state, lj_model) + step_fn = ts.lbfgs_step + + initial_position = state.positions[indices].clone() + + # Run optimization + energies = [1000, state.energy.item()] + max_steps = 500 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = step_fn(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + final_position = state.positions[indices] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + + +@pytest.mark.parametrize("optimizer", ["bfgs", "lbfgs"]) +def test_fix_com_bfgs_lbfgs_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + optimizer: str, +) -> None: + """Test FixCom constraint in BFGS/LBFGS optimization.""" + # Create a fresh copy with random displacement + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + current_sim_state.constraints = [FixCom([0])] + + # Initialize optimizer + if optimizer == "bfgs": + state = ts.bfgs_init(current_sim_state, lj_model) + step_fn = ts.bfgs_step + else: + state = ts.lbfgs_init(current_sim_state, lj_model) + step_fn = ts.lbfgs_step + + initial_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + # Run optimization + energies = [1000, state.energy.item()] + max_steps = 500 + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = step_fn(state=state, model=lj_model) + energies.append(state.energy.item()) + steps_taken += 1 + + final_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + assert torch.allclose(final_com, initial_com, atol=1e-4) + + def test_fix_atoms_validation() -> None: """Test FixAtoms construction and validation.""" # Boolean mask conversion @@ -587,6 +692,41 @@ def test_cell_optimization_with_constraints( assert len(state.constraints) > 0 +@pytest.mark.parametrize( + ("cell_filter", "optimizer"), + [ + (ts.CellFilter.unit, "bfgs"), + (ts.CellFilter.frechet, "bfgs"), + (ts.CellFilter.unit, "lbfgs"), + (ts.CellFilter.frechet, "lbfgs"), + ], +) +def test_cell_optimization_with_constraints_bfgs_lbfgs( + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + cell_filter: str, + optimizer: str, +) -> None: + """Test cell filters work with constraints for BFGS/LBFGS.""" + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.05 + ) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1])] + + if optimizer == "bfgs": + state = ts.bfgs_init(ar_supercell_sim_state, lj_model, cell_filter=cell_filter) + step_fn = ts.bfgs_step + else: + state = ts.lbfgs_init(ar_supercell_sim_state, lj_model, cell_filter=cell_filter) + step_fn = ts.lbfgs_step + + for _ in range(50): + state = step_fn(state, lj_model) + if state.forces.abs().max() < 0.05: + break + assert len(state.constraints) > 0 + + def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: """Test system-specific constraints in batched states.""" s1, s2 = ar_double_sim_state.split() diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index 3bdfbb0be..ccf84c4d4 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -193,6 +193,7 @@ def bfgs_init( "cell_filter": cell_filter_funcs, "charge": state.charge, # preserve charge "spin": state.spin, # preserve spin + "_constraints": state.constraints, # preserve constraints } cell_state = CellBFGSState(**common_args) @@ -233,6 +234,7 @@ def bfgs_init( "pbc": state.pbc, # [S, 3] "charge": state.charge, # preserve charge "spin": state.spin, # preserve spin + "_constraints": state.constraints, # preserve constraints } return BFGSState(**common_args) @@ -527,15 +529,15 @@ def bfgs_step( # noqa: C901, PLR0915 new_frac.unsqueeze(1), # [N, 1, 3] new_deform_grad[state.system_idx].transpose(-2, -1), # [N, 3, 3] ).squeeze(1) # [N, 3] - state.positions = new_positions # [N, 3] + state.set_constrained_positions(new_positions) # [N, 3] else: state.prev_positions = state.positions.clone() # [N, 3] state.prev_forces = state.forces.clone() # [N, 3] - state.positions = state.positions + flat_step # [N, 3] + state.set_constrained_positions(state.positions + flat_step) # [N, 3] # Evaluate new forces and energy model_output = model(state) - state.forces = model_output["forces"] # [N, 3] + state.set_constrained_forces(model_output["forces"]) # [N, 3] state.energy = model_output["energy"] # [S] if "stress" in model_output: state.stress = model_output["stress"] # [S, 3, 3] From bd360b40eafc2f5d4339aa62f2a941eb9a639699 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 03:13:21 -0800 Subject: [PATCH 13/17] Add batched lbfgs --- torch_sim/optimizers/lbfgs.py | 606 ++++++++++++++++++++++++++-------- 1 file changed, 471 insertions(+), 135 deletions(-) diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 6644c94cb..30fe2b321 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -4,20 +4,100 @@ L-BFGS is a quasi-Newton method that approximates the inverse Hessian using a limited history of position and gradient differences, making it memory-efficient for large systems while achieving superlinear convergence near the minimum. + +When cell_filter is active, forces are transformed using the deformation gradient +to work in the same scaled coordinate space as ASE's UnitCellFilter/FrechetCellFilter. +The prev_forces and prev_positions are stored in the scaled/fractional space to match +ASE's behavior exactly. """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch -import torch_sim.math as tsm +import torch_sim as ts +from torch_sim.optimizers import cell_filters +from torch_sim.optimizers.cell_filters import frechet_cell_filter_init from torch_sim.state import SimState from torch_sim.typing import StateDict if TYPE_CHECKING: from torch_sim.models.interface import ModelInterface - from torch_sim.optimizers import LBFGSState + from torch_sim.optimizers import CellLBFGSState, LBFGSState + from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs + + +def _atoms_to_padded( + x: torch.Tensor, + system_idx: torch.Tensor, + n_systems: int, + max_atoms: int, +) -> torch.Tensor: + """Convert atom-indexed [N, 3] to padded per-system [S, M, 3]. + + Args: + x: Tensor of shape [N, 3] where N = total atoms + system_idx: System index for each atom [N] + n_systems: Number of systems S + max_atoms: Maximum atoms per system M + + Returns: + Tensor of shape [S, M, 3] with zeros for padding + """ + device, dtype = x.device, x.dtype + out = torch.zeros((n_systems, max_atoms, 3), device=device, dtype=dtype) + # Create atom index within each system + atom_idx = torch.zeros_like(system_idx) + for sys in range(n_systems): + mask = system_idx == sys + atom_idx[mask] = torch.arange(mask.sum(), device=device) + out[system_idx, atom_idx] = x + return out + + +def _padded_to_atoms( + x: torch.Tensor, + system_idx: torch.Tensor, + n_atoms: int, +) -> torch.Tensor: + """Convert padded per-system [S, M, 3] to atom-indexed [N, 3]. + + Args: + x: Tensor of shape [S, M, 3] + system_idx: System index for each atom [N] + n_atoms: Total number of atoms N + + Returns: + Tensor of shape [N, 3] + """ + n_systems = x.shape[0] + device = x.device + # Create atom index within each system + atom_idx = torch.zeros(n_atoms, device=device, dtype=torch.long) + for sys in range(n_systems): + mask = system_idx == sys + atom_idx[mask] = torch.arange(mask.sum(), device=device) + return x[system_idx, atom_idx] # [N, 3] + + +def _per_system_vdot( + a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor +) -> torch.Tensor: + """Compute per-system dot product with padding mask. + + Args: + a: Tensor of shape [S, M, 3] + b: Tensor of shape [S, M, 3] + mask: Boolean mask [S, M] where True = valid atom + + Returns: + Tensor of shape [S] with per-system dot products + """ + # Element-wise product then sum over atoms and coordinates + prod = (a * b).sum(dim=-1) # [S, M] + prod = prod * mask.float() # Zero out padded atoms + return prod.sum(dim=-1) # [S] def lbfgs_init( @@ -26,12 +106,21 @@ def lbfgs_init( *, step_size: float = 0.1, alpha: float | None = None, -) -> "LBFGSState": + cell_filter: "CellFilter | CellFilterFuncs | None" = None, + **filter_kwargs: Any, +) -> "LBFGSState | CellLBFGSState": r"""Create an initial LBFGSState from a SimState or state dict. Initializes forces/energy, clears the (s, y) memory, and broadcasts the fixed step size to all systems. + Shape notation: + N = total atoms across all systems (n_atoms) + S = number of systems (n_systems) + M = max atoms per system (global_max_atoms) + H = history length (starts at 0) + M_ext = M + 3 (extended with cell DOFs per system) + Args: state: Input state as SimState object or state parameter dict model: Model that computes energies, forces, and optionally stress @@ -41,9 +130,12 @@ def lbfgs_init( alpha: Initial inverse Hessian stiffness guess (ASE parameter). If provided (e.g. 70.0), fixes H0 = 1/alpha for all steps (ASE-style). If None (default), H0 is updated dynamically (Standard L-BFGS). + cell_filter: Filter for cell optimization (None for position-only optimization) + **filter_kwargs: Additional arguments passed to cell filter initialization Returns: - LBFGSState with initialized optimization tensors + LBFGSState with initialized optimization tensors, or CellLBFGSState if + cell_filter is provided Notes: The optimizer supports two modes of operation: @@ -56,73 +148,146 @@ def lbfgs_init( optimization, and the step is scaled by `step_size` (damping). This matches `ase.optimize.LBFGS(alpha=70.0, damping=1.0)`. """ - from torch_sim.optimizers import LBFGSState + from torch_sim.optimizers import CellLBFGSState, LBFGSState tensor_args = {"device": model.device, "dtype": model.dtype} if not isinstance(state, SimState): state = SimState(**state) - n_systems = state.n_systems + n_systems = state.n_systems # S + + # Compute max atoms per system for per-system history storage + counts = state.n_atoms_per_system # [S] + global_max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 # M + max_atoms = counts.clone() # [S] - each system's atom count # Get initial forces and energy from model model_output = model(state) - energy = model_output["energy"] - forces = model_output["forces"] - stress = model_output["stress"] - - # Initialize empty history tensors - # History shape: [max_history, n_atoms, 3] but we start with 0 entries - s_history = torch.zeros((0, state.n_atoms, 3), **tensor_args) - y_history = torch.zeros((0, state.n_atoms, 3), **tensor_args) + energy = model_output["energy"] # [S] + forces = model_output["forces"] # [N, 3] + stress = model_output.get("stress") # [S, 3, 3] or None + + # Initialize empty per-system history tensors + # History shape: [S, H, M, 3] where H=0 at start, M = global_max_atoms + s_history = torch.zeros( + (n_systems, 0, global_max_atoms, 3), **tensor_args + ) # [S, 0, M, 3] + y_history = torch.zeros( + (n_systems, 0, global_max_atoms, 3), **tensor_args + ) # [S, 0, M, 3] # Alpha tensor: 0.0 means dynamic, >0 means fixed alpha_val = 0.0 if alpha is None else alpha - alpha_tensor = torch.full((n_systems,), alpha_val, **tensor_args) + alpha_tensor = torch.full((n_systems,), alpha_val, **tensor_args) # [S] - return LBFGSState( + common_args = { # Copy SimState attributes - positions=state.positions.clone(), - masses=state.masses.clone(), - cell=state.cell.clone(), - atomic_numbers=state.atomic_numbers.clone(), - system_idx=state.system_idx.clone(), - pbc=state.pbc, + "positions": state.positions.clone(), # [N, 3] + "masses": state.masses.clone(), # [N] + "cell": state.cell.clone(), # [S, 3, 3] + "atomic_numbers": state.atomic_numbers.clone(), # [N] + "system_idx": state.system_idx.clone(), # [N] + "pbc": state.pbc, # [S, 3] + "charge": state.charge, # preserve charge + "spin": state.spin, # preserve spin + "_constraints": state.constraints, # preserve constraints # Optimization state - forces=forces, - energy=energy, - stress=stress, + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None # L-BFGS specific state - prev_forces=forces.clone(), - prev_positions=state.positions.clone(), - s_history=s_history, - y_history=y_history, - step_size=torch.full((n_systems,), step_size, **tensor_args), - alpha=alpha_tensor, - n_iter=torch.zeros((n_systems,), device=model.device, dtype=torch.int32), - ) - - -def lbfgs_step( # noqa: PLR0915 - state: "LBFGSState", + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "s_history": s_history, # [S, 0, M, 3] + "y_history": y_history, # [S, 0, M, 3] + "step_size": torch.full((n_systems,), step_size, **tensor_args), # [S] + "alpha": alpha_tensor, # [S] + "n_iter": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), + "max_atoms": max_atoms, # [S] atoms per system for padding + } + + if cell_filter is not None: + cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) + + # At initialization, deform_grad is identity since reference_cell = current_cell + # Store prev_positions as fractional (same as Cartesian for identity deform_grad) + # Store prev_forces as scaled (same as Cartesian for identity deform_grad) + reference_cell = state.cell.clone() # [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + reference_cell.mT, state.cell.mT + ) # [S, 3, 3] + + # Initial fractional positions = positions + # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] -> [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Initial scaled forces = forces @ deform_grad = forces + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] + scaled_forces = torch.bmm( + forces.unsqueeze(1), # [N, 1, 3] + cur_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(-1) # [N, 3] + + common_args["reference_cell"] = reference_cell # [S, 3, 3] + common_args["cell_filter"] = cell_filter_funcs + # Store fractional positions and scaled forces for ASE compatibility + common_args["prev_positions"] = frac_positions # [N, 3] + common_args["prev_forces"] = scaled_forces # [N, 3] + + # Extended per-system history includes cell DOFs (3 "virtual atoms" per system) + # History shape: [S, H, M+3, 3] where M = global_max_atoms + extended_size_per_system = global_max_atoms + 3 # M_ext = M + 3 + common_args["s_history"] = torch.zeros( + (n_systems, 0, extended_size_per_system, 3), **tensor_args + ) # [S, 0, M_ext, 3] + common_args["y_history"] = torch.zeros( + (n_systems, 0, extended_size_per_system, 3), **tensor_args + ) # [S, 0, M_ext, 3] + + cell_state = CellLBFGSState(**common_args) + + # Initialize cell-specific attributes + # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] + init_fn(cell_state, model, **filter_kwargs) + + # Store prev_cell_positions and prev_cell_forces for history update + cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] + cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + + return cell_state + + return LBFGSState(**common_args) + + +def lbfgs_step( # noqa: PLR0915, C901 + state: "LBFGSState | CellLBFGSState", model: "ModelInterface", *, max_history: int = 10, max_step: float = 0.2, curvature_eps: float = 1e-12, -) -> "LBFGSState": +) -> "LBFGSState | CellLBFGSState": r"""Advance one L-BFGS iteration using the two-loop recursion. Computes the search direction via the two-loop recursion, applies a fixed step with optional per-system capping, evaluates new forces and energy, and updates the limited-memory history with a curvature check. - Algorithm (per system s): - 1) Evaluate gradient g_k = ∇E(x_k) = -f(x_k) - 2) Perform L-BFGS two-loop recursion using up to `max_history` pairs - (s_i, y_i) to compute d_k = -H_k g_k - 3) Fixed step update with optional per-system step capping by `max_step` - 4) Curvature check and history update: accept (s_k, y_k) if ⟨y_k, s_k⟩ > ε + When cell_filter is active, forces are transformed using the deformation + gradient to work in the same scaled coordinate space as ASE's cell filters. + The prev_positions are stored as fractional coordinates and prev_forces as + scaled forces, exactly matching ASE's pos0/forces0. + + Shape notation: + N = total atoms across all systems (n_atoms) + S = number of systems (n_systems) + M = max atoms per system (history dimension) + H = current history length + M_ext = M + 3 (extended with cell DOFs per system) Args: state: Current L-BFGS optimization state @@ -144,143 +309,314 @@ def lbfgs_step( # noqa: PLR0915 References: - Nocedal & Wright, Numerical Optimization (L-BFGS two-loop recursion). """ + from torch_sim.optimizers import CellLBFGSState + + is_cell_state = isinstance(state, CellLBFGSState) device, dtype = model.device, model.dtype eps = 1e-8 if dtype == torch.float32 else 1e-16 - - # Current gradient - g = -state.forces + n_systems = state.n_systems # S + n_atoms = state.n_atoms # N + + # Derive max_atoms from history shape: [S, H, M, 3] or [S, H, M_ext, 3] + history_dim = state.s_history.shape[2] # M or M_ext + if is_cell_state: + max_atoms_ext = history_dim # M_ext = M + 3 + max_atoms = max_atoms_ext - 3 # M + else: + max_atoms = history_dim # M + max_atoms_ext = max_atoms + + # Create atom index within each system for padding/unpadding + atom_idx_in_sys = torch.zeros(n_atoms, device=device, dtype=torch.long) + for sys in range(n_systems): + mask = state.system_idx == sys + atom_idx_in_sys[mask] = torch.arange(mask.sum(), device=device) + + # Create valid atom mask for per-system operations: [S, M] + atom_mask = torch.zeros((n_systems, max_atoms), device=device, dtype=torch.bool) + for sys in range(n_systems): + n_atoms_sys = int(state.max_atoms[sys].item()) + atom_mask[sys, :n_atoms_sys] = True + + # Extended mask including cell DOFs: [S, M_ext] + if is_cell_state: + ext_mask = torch.cat( + [ + atom_mask, + torch.ones((n_systems, 3), device=device, dtype=torch.bool), + ], + dim=1, + ) # [S, M_ext] + else: + ext_mask = atom_mask # [S, M] + + if is_cell_state: + # Get current deformation gradient + # reference_cell.mT: [S, 3, 3], row_vector_cell: [S, 3, 3] + cur_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + + # Transform forces to scaled coordinates + # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] + forces_scaled = torch.bmm( + state.forces.unsqueeze(1), # [N, 1, 3] + cur_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) # [N, 3] + + # Current fractional positions + # positions: [N, 3] -> frac_positions: [N, 3] + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # Convert to padded per-system format: [S, M, 3] + g_atoms = _atoms_to_padded(-forces_scaled, state.system_idx, n_systems, max_atoms) + # Cell forces: [S, 3, 3] -> [S, 3, 3] + g_cell = -state.cell_forces # [S, 3, 3] + # Extended gradient: [S, M_ext, 3] = [S, M+3, 3] + g = torch.cat([g_atoms, g_cell], dim=1) # [S, M_ext, 3] + else: + # Convert to padded per-system format: [S, M, 3] + g = _atoms_to_padded(-state.forces, state.system_idx, n_systems, max_atoms) # Two-loop recursion to compute search direction d = -H_k g_k - q = g.clone() - alphas: list[torch.Tensor] = [] # per-history, shape [n_systems] + # History shape: [S, H, M_ext, 3] or [S, H, M, 3] + cur_history_len = state.s_history.shape[1] # H + q = g.clone() # [S, M_ext, 3] or [S, M, 3] + alphas: list[torch.Tensor] = [] # list of [S] tensors # First loop (from newest to oldest) - for i in range(state.s_history.shape[0] - 1, -1, -1): - s_i = state.s_history[i] - y_i = state.y_history[i] + for i in range(cur_history_len - 1, -1, -1): + s_i = state.s_history[:, i] # [S, M_ext, 3] or [S, M, 3] + y_i = state.y_history[:, i] # [S, M_ext, 3] or [S, M, 3] - ys = tsm.batched_vdot(y_i, s_i, state.system_idx) # y^T s per system + # ys = y^T s per system: [S] + ys = _per_system_vdot(y_i, s_i, ext_mask) # [S] rho = torch.where( ys.abs() > curvature_eps, 1.0 / (ys + eps), torch.zeros_like(ys), - ) - sq = tsm.batched_vdot(s_i, q, state.system_idx) - alpha = rho * sq + ) # [S] + sq = _per_system_vdot(s_i, q, ext_mask) # [S] + alpha = rho * sq # [S] alphas.append(alpha) - # q <- q - alpha * y_i (broadcast per system to atoms) - alpha_atom = alpha[state.system_idx].unsqueeze(-1) - q = q - alpha_atom * y_i + # q <- q - alpha * y_i (broadcast alpha to [S, 1, 1]) + q = q - alpha.view(-1, 1, 1) * y_i # [S, M_ext, 3] # Initial H0 scaling: gamma = (s^T y)/(y^T y) using the last pair - # Dynamic gamma (Standard L-BFGS) - if state.s_history.shape[0] > 0: - s_last = state.s_history[-1] - y_last = state.y_history[-1] - sy = tsm.batched_vdot(s_last, y_last, state.system_idx) - yy = tsm.batched_vdot(y_last, y_last, state.system_idx) + if cur_history_len > 0: + s_last = state.s_history[:, -1] # [S, M_ext, 3] + y_last = state.y_history[:, -1] # [S, M_ext, 3] + sy = _per_system_vdot(s_last, y_last, ext_mask) # [S] + yy = _per_system_vdot(y_last, y_last, ext_mask) # [S] gamma_dynamic = torch.where( yy.abs() > curvature_eps, sy / (yy + eps), torch.ones_like(yy), - ) + ) # [S] else: - gamma_dynamic = torch.ones((state.n_systems,), device=device, dtype=dtype) + gamma_dynamic = torch.ones((n_systems,), device=device, dtype=dtype) # [S] # Fixed gamma (ASE style: 1/alpha) # If state.alpha > 0, use that. Else use dynamic. - is_fixed = state.alpha > 1e-6 - gamma_fixed = 1.0 / (state.alpha + eps) - gamma = torch.where(is_fixed, gamma_fixed, gamma_dynamic) + is_fixed = state.alpha > 1e-6 # [S] bool + gamma_fixed = 1.0 / (state.alpha + eps) # [S] + gamma = torch.where(is_fixed, gamma_fixed, gamma_dynamic) # [S] - z = gamma[state.system_idx].unsqueeze(-1) * q + # z = gamma * q (broadcast gamma to [S, 1, 1]) + z = gamma.view(-1, 1, 1) * q # [S, M_ext, 3] # Second loop (from oldest to newest) - for i in range(state.s_history.shape[0]): - s_i = state.s_history[i] - y_i = state.y_history[i] + for i in range(cur_history_len): + s_i = state.s_history[:, i] # [S, M_ext, 3] + y_i = state.y_history[:, i] # [S, M_ext, 3] - ys = tsm.batched_vdot(y_i, s_i, state.system_idx) + ys = _per_system_vdot(y_i, s_i, ext_mask) # [S] rho = torch.where( ys.abs() > curvature_eps, 1.0 / (ys + eps), torch.zeros_like(ys), - ) - yz = tsm.batched_vdot(y_i, z, state.system_idx) - beta = rho * yz + ) # [S] + yz = _per_system_vdot(y_i, z, ext_mask) # [S] + beta = rho * yz # [S] - alpha = alphas[state.s_history.shape[0] - 1 - i] + alpha_i = alphas[cur_history_len - 1 - i] # [S] # z <- z + s_i * (alpha - beta) - coeff = (alpha - beta)[state.system_idx].unsqueeze(-1) - z = z + s_i * coeff + coeff = (alpha_i - beta).view(-1, 1, 1) # [S, 1, 1] + z = z + s_i * coeff # [S, M_ext, 3] - d = -z # search direction + d = -z # search direction: [S, M_ext, 3] - # Optional per-system max step cap - # Compute per-atom step with current step_size - t_atoms = state.step_size[state.system_idx].unsqueeze(-1) - step = t_atoms * d + # Apply step_size scaling per system: [S, 1, 1] + step = state.step_size.view(-1, 1, 1) * d # [S, M_ext, 3] - # Per-atom norms - norms = torch.linalg.norm(step, dim=1) - - # Per-system max norm - sys_max = torch.zeros(state.n_systems, device=device, dtype=dtype) - sys_max.scatter_reduce_(0, state.system_idx, norms, reduce="amax", include_self=False) + # Per-system max norm (only over valid atoms/DOFs) + step_norms = torch.linalg.norm(step, dim=-1) # [S, M_ext] + step_norms = step_norms * ext_mask.float() # Zero out padded + sys_max = step_norms.max(dim=1).values # [S] # Scaling factors per system: <= 1.0 scale = torch.where( sys_max > max_step, max_step / (sys_max + eps), torch.ones_like(sys_max), - ) - scale_atoms = scale[state.system_idx].unsqueeze(-1) - step = scale_atoms * step - - # Update positions - new_positions = state.positions + step + ) # [S] + step = scale.view(-1, 1, 1) * step # [S, M_ext, 3] + + # Split step into position and cell components + if is_cell_state: + step_padded = step[:, :max_atoms] # [S, M, 3] + step_cell = step[:, max_atoms:] # [S, 3, 3] + # Convert padded step to atom-level + step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms) + else: + step_padded = step # [S, M, 3] + step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms) + + # Save previous state for history update + # For cell state: store fractional positions and scaled forces (ASE convention) + if is_cell_state: + state.prev_positions = frac_positions.clone() # [N, 3] (fractional) + state.prev_forces = forces_scaled.clone() # [N, 3] (scaled) + state.prev_cell_positions = state.cell_positions.clone() # [S, 3, 3] + state.prev_cell_forces = state.cell_forces.clone() # [S, 3, 3] + + # Apply cell step + dr_cell = step_cell # [S, 3, 3] + cell_positions_new = state.cell_positions + dr_cell # [S, 3, 3] + state.cell_positions = cell_positions_new # [S, 3, 3] + + # Determine if Frechet filter + init_fn, _step_fn = state.cell_filter + is_frechet = init_fn is frechet_cell_filter_init + + if is_frechet: + # Frechet: deform_grad = exp(cell_positions / cell_factor) + cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1) + deform_grad_log_new = cell_positions_new / cell_factor_reshaped # [S, 3, 3] + deform_grad_new = torch.matrix_exp(deform_grad_log_new) # [S, 3, 3] + else: + # UnitCell: deform_grad = cell_positions / cell_factor + cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1) + deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3] + + # Update cell: new_cell = reference_cell @ deform_grad^T + # reference_cell.mT: [S, 3, 3], deform_grad_new: [S, 3, 3] + state.row_vector_cell = torch.bmm( + state.reference_cell.mT, deform_grad_new.transpose(-2, -1) + ) # [S, 3, 3] + + # Apply position step in fractional space, then convert to Cartesian + new_frac = frac_positions + step_positions # [N, 3] + + new_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + # new_positions = new_frac @ deform_grad^T + new_positions = torch.bmm( + new_frac.unsqueeze(1), # [N, 1, 3] + new_deform_grad[state.system_idx].transpose(-2, -1), # [N, 3, 3] + ).squeeze(1) # [N, 3] + state.set_constrained_positions(new_positions) # [N, 3] + else: + state.prev_positions = state.positions.clone() # [N, 3] + state.prev_forces = state.forces.clone() # [N, 3] + state.set_constrained_positions(state.positions + step_positions) # [N, 3] # Evaluate new forces/energy - state.positions = new_positions model_output = model(state) - new_forces = model_output["forces"] - new_energy = model_output["energy"] - new_stress = model_output["stress"] - - # Build new (s, y) - s_new = state.positions - state.prev_positions - y_new = -new_forces - (-state.prev_forces) # g_new - g_prev = -(f_new - f_prev) - - # Curvature check per system; if bad, clear history (conservative) - sy = tsm.batched_vdot(s_new, y_new, state.system_idx) - bad_curv = sy <= curvature_eps - - if bad_curv.any(): - # Clear entire history to preserve correctness - s_hist = torch.zeros((0, state.n_atoms, 3), device=device, dtype=dtype) - y_hist = torch.zeros((0, state.n_atoms, 3), device=device, dtype=dtype) + new_forces = model_output["forces"] # [N, 3] + new_energy = model_output["energy"] # [S] + new_stress = model_output.get("stress") # [S, 3, 3] or None + + # Update cell forces for next step: [S, 3, 3] + if is_cell_state: + cell_filters.compute_cell_forces(model_output, state) + + # Build new (s, y) for history in per-system format [S, M_ext, 3] or [S, M, 3] + # s = position difference, y = gradient difference + if is_cell_state: + # Get new scaled forces and fractional positions for history + new_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) # [S, 3, 3] + # new_forces: [N, 3] -> new_forces_scaled: [N, 3] + new_forces_scaled = torch.bmm( + new_forces.unsqueeze(1), # [N, 1, 3] + new_deform_grad[state.system_idx], # [N, 3, 3] + ).squeeze(1) # [N, 3] + # positions: [N, 3] -> new_frac_positions: [N, 3] + new_frac_positions = torch.linalg.solve( + new_deform_grad[state.system_idx], # [N, 3, 3] + state.positions.unsqueeze(-1), # [N, 3, 1] + ).squeeze(-1) # [N, 3] + + # s_new_pos = frac_pos_new - frac_pos_old: [N, 3] -> [S, M, 3] + s_new_pos_atoms = new_frac_positions - state.prev_positions # [N, 3] + s_new_pos = _atoms_to_padded( + s_new_pos_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + # s_new_cell = cell_pos_new - cell_pos_old: [S, 3, 3] + s_new_cell = state.cell_positions - state.prev_cell_positions # [S, 3, 3] + # Concatenate to extended format: [S, M_ext, 3] + s_new = torch.cat([s_new_pos, s_new_cell], dim=1) # [S, M_ext, 3] + + # y_new = grad_diff for positions and cell (gradient = -forces) + # y = grad_new - grad_old = -forces_new - (-forces_old) = forces_old - forces_new + y_new_pos_atoms = -new_forces_scaled - (-state.prev_forces) # [N, 3] + y_new_pos = _atoms_to_padded( + y_new_pos_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + y_new_cell = -state.cell_forces - (-state.prev_cell_forces) # [S, 3, 3] + y_new = torch.cat([y_new_pos, y_new_cell], dim=1) # [S, M_ext, 3] else: - # Append and trim if needed - if state.s_history.shape[0] == 0: - s_hist = s_new.unsqueeze(0) - y_hist = y_new.unsqueeze(0) - else: - s_hist = torch.cat([state.s_history, s_new.unsqueeze(0)], dim=0) - y_hist = torch.cat([state.y_history, y_new.unsqueeze(0)], dim=0) - if s_hist.shape[0] > max_history: - s_hist = s_hist[-max_history:] - y_hist = y_hist[-max_history:] + # s_new = pos_new - pos_old: [N, 3] -> [S, M, 3] + s_new_atoms = state.positions - state.prev_positions # [N, 3] + s_new = _atoms_to_padded( + s_new_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + # y_new = grad_diff: [N, 3] -> [S, M, 3] + y_new_atoms = -new_forces - (-state.prev_forces) # [N, 3] + y_new = _atoms_to_padded( + y_new_atoms, state.system_idx, n_systems, max_atoms + ) # [S, M, 3] + + # Append history and trim if needed + # Note: ASE's L-BFGS doesn't have a curvature check for adding to history. + # Invalid curvatures are handled in the two-loop by checking rho. + # History tensors: [S, H, M_ext, 3] or [S, H, M, 3] + cur_history_len = state.s_history.shape[1] # H + if cur_history_len == 0: + # First entry: [S, 1, M_ext, 3] or [S, 1, M, 3] + s_hist = s_new.unsqueeze(1) # [S, 1, M_ext, 3] + y_hist = y_new.unsqueeze(1) # [S, 1, M_ext, 3] + else: + # Append new entry: [S, H, ...] cat [S, 1, ...] -> [S, H+1, ...] + s_hist = torch.cat([state.s_history, s_new.unsqueeze(1)], dim=1) + y_hist = torch.cat([state.y_history, y_new.unsqueeze(1)], dim=1) + # Trim to max_history + if s_hist.shape[1] > max_history: + s_hist = s_hist[:, -max_history:] # [S, max_history, ...] + y_hist = y_hist[:, -max_history:] # Update state - state.forces = new_forces - state.energy = new_energy - state.stress = new_stress - - state.prev_forces = new_forces.clone() - state.prev_positions = state.positions.clone() - state.s_history = s_hist - state.y_history = y_hist - state.n_iter = state.n_iter + 1 + state.set_constrained_forces(new_forces) # [N, 3] + state.energy = new_energy # [S] + state.stress = new_stress # [S, 3, 3] or None + + if is_cell_state: + # Store fractional/scaled for next iteration + state.prev_positions = new_frac_positions.clone() # [N, 3] (fractional) + state.prev_forces = new_forces_scaled.clone() # [N, 3] (scaled) + else: + state.prev_forces = new_forces.clone() # [N, 3] + state.prev_positions = state.positions.clone() # [N, 3] + + state.s_history = s_hist # [S, H, M_ext, 3] or [S, H, M, 3] + state.y_history = y_hist # [S, H, M_ext, 3] or [S, H, M, 3] + state.n_iter = state.n_iter + 1 # [S] return state From 0cd42b37a309ecc59800141258f360c8c4c1883d Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 05:28:29 -0800 Subject: [PATCH 14/17] bump max history and fix shape issue --- examples/scripts/2_structural_optimization.py | 16 ++++++++-------- torch_sim/optimizers/lbfgs.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/scripts/2_structural_optimization.py b/examples/scripts/2_structural_optimization.py index 3adf0b4f2..0f8d7d052 100644 --- a/examples/scripts/2_structural_optimization.py +++ b/examples/scripts/2_structural_optimization.py @@ -31,7 +31,7 @@ # Number of steps to run SMOKE_TEST = os.getenv("CI") is not None -N_steps = 10 if SMOKE_TEST else 100 +N_steps = 10 if SMOKE_TEST else 500 # ============================================================================ @@ -111,7 +111,7 @@ # Run optimization for step in range(N_steps): - if step % (N_steps // 5) == 0: + if step % 100 == 0: print(f"Step {step}: Potential energy: {state.energy[0].item()} eV") state = ts.fire_step(state=state, model=lj_model, dt_max=0.01) @@ -174,7 +174,7 @@ print("\nRunning FIRE:") for step in range(N_steps): - if step % (N_steps // 5) == 0: + if step % 100 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") state = ts.fire_step(state=state, model=model, dt_max=0.01) @@ -254,7 +254,7 @@ print("\nRunning batched unit cell gradient descent:") for step in range(N_steps): - if step % (N_steps // 5) == 0: + if step % 100 == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -308,7 +308,7 @@ print("\nRunning batched unit cell FIRE:") for step in range(N_steps): - if step % (N_steps // 5) == 0: + if step % 100 == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -360,7 +360,7 @@ print("\nRunning batched frechet cell filter with FIRE:") for step in range(N_steps): - if step % (N_steps // 5) == 0: + if step % 100 == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -411,7 +411,7 @@ print("\nRunning L-BFGS:") for step in range(N_steps): - if step % (N_steps // 5) == 0: + if step % 100 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") state = ts.lbfgs_step(state=state, model=model, max_history=100) @@ -444,7 +444,7 @@ print("\nRunning BFGS:") for step in range(N_steps): - if step % (N_steps // 5) == 0: + if step % 100 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") state = ts.bfgs_step(state=state, model=model) diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 30fe2b321..f4dd79a34 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -230,7 +230,7 @@ def lbfgs_init( scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] - ).squeeze(-1) # [N, 3] + ).squeeze(1) # [N, 3] common_args["reference_cell"] = reference_cell # [S, 3, 3] common_args["cell_filter"] = cell_filter_funcs @@ -267,7 +267,7 @@ def lbfgs_step( # noqa: PLR0915, C901 state: "LBFGSState | CellLBFGSState", model: "ModelInterface", *, - max_history: int = 10, + max_history: int = 20, max_step: float = 0.2, curvature_eps: float = 1e-12, ) -> "LBFGSState | CellLBFGSState": From 69ce9f43e0870bb449fd0a3d44c2ba84ba1562c7 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 05:38:26 -0800 Subject: [PATCH 15/17] add optimizer test and remove redundant autobatcher test --- examples/scripts/2_structural_optimization.py | 14 +- tests/test_autobatching.py | 66 ++------- tests/test_optimizers.py | 140 ++++++++++++++++++ 3 files changed, 162 insertions(+), 58 deletions(-) diff --git a/examples/scripts/2_structural_optimization.py b/examples/scripts/2_structural_optimization.py index 0f8d7d052..ccf5b9019 100644 --- a/examples/scripts/2_structural_optimization.py +++ b/examples/scripts/2_structural_optimization.py @@ -111,7 +111,7 @@ # Run optimization for step in range(N_steps): - if step % 100 == 0: + if step % 20 == 0: print(f"Step {step}: Potential energy: {state.energy[0].item()} eV") state = ts.fire_step(state=state, model=lj_model, dt_max=0.01) @@ -174,7 +174,7 @@ print("\nRunning FIRE:") for step in range(N_steps): - if step % 100 == 0: + if step % 20 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") state = ts.fire_step(state=state, model=model, dt_max=0.01) @@ -254,7 +254,7 @@ print("\nRunning batched unit cell gradient descent:") for step in range(N_steps): - if step % 100 == 0: + if step % 20 == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -308,7 +308,7 @@ print("\nRunning batched unit cell FIRE:") for step in range(N_steps): - if step % 100 == 0: + if step % 20 == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -360,7 +360,7 @@ print("\nRunning batched frechet cell filter with FIRE:") for step in range(N_steps): - if step % 100 == 0: + if step % 20 == 0: P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3 P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3 @@ -411,7 +411,7 @@ print("\nRunning L-BFGS:") for step in range(N_steps): - if step % 100 == 0: + if step % 20 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") state = ts.lbfgs_step(state=state, model=model, max_history=100) @@ -444,7 +444,7 @@ print("\nRunning BFGS:") for step in range(N_steps): - if step % 100 == 0: + if step % 20 == 0: print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}") state = ts.bfgs_step(state=state, model=model) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index ae5f4f8ca..c3e9aea23 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -620,7 +620,7 @@ def test_in_flight_with_bfgs( lj_model: LennardJonesModel, num_steps_per_batch: int, ) -> None: - """Test InFlightAutoBatcher with BFGS optimizer (matching FIRE test structure).""" + """Test InFlightAutoBatcher with BFGS optimizer.""" si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) fe_bfgs_state = ts.bfgs_init( fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit @@ -668,7 +668,7 @@ def test_binning_auto_batcher_with_bfgs( fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, ) -> None: - """Test BinningAutoBatcher with BFGS optimizer (matching FIRE test structure).""" + """Test BinningAutoBatcher with BFGS optimizer.""" si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) fe_bfgs_state = ts.bfgs_init( fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit @@ -695,24 +695,6 @@ def test_binning_auto_batcher_with_bfgs( assert len(all_finished_states) == len(bfgs_states) -def _group_states_by_size( - states: list[ts.SimState], -) -> list[list[tuple[int, ts.SimState]]]: - """Group states by n_atoms, preserving original indices for order restoration. - - Used for L-BFGS which requires same-sized systems in each batch due to - history tensor shapes being dependent on n_atoms. - """ - from itertools import groupby - - indexed_states = list(enumerate(states)) - sorted_states = sorted(indexed_states, key=lambda x: x[1].n_atoms) - groups = [] - for _, group in groupby(sorted_states, key=lambda x: x[1].n_atoms): - groups.append(list(group)) - return groups - - @pytest.mark.parametrize( "num_steps_per_batch", [ @@ -726,7 +708,7 @@ def test_in_flight_with_lbfgs( lj_model: LennardJonesModel, num_steps_per_batch: int, ) -> None: - """Test InFlightAutoBatcher with L-BFGS optimizer (matching FIRE test structure).""" + """Test InFlightAutoBatcher with L-BFGS optimizer.""" si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) fe_lbfgs_state = ts.lbfgs_init( fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit @@ -774,7 +756,7 @@ def test_binning_auto_batcher_with_lbfgs( fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, ) -> None: - """Test BinningAutoBatcher with L-BFGS optimizer (matching FIRE test structure).""" + """Test BinningAutoBatcher with L-BFGS optimizer.""" si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit) fe_lbfgs_state = ts.lbfgs_init( fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit @@ -785,35 +767,17 @@ def test_binning_auto_batcher_with_lbfgs( for state in lbfgs_states: state.positions += torch.randn_like(state.positions) * 0.01 - # Group by size and process each group separately - size_groups = _group_states_by_size(lbfgs_states) - all_finished_with_indices: list[tuple[int, ts.SimState]] = [] - total_batches = 0 - - for group in size_groups: - original_indices, group_states = zip(*group, strict=True) - group_states_list = list(group_states) - - batcher = BinningAutoBatcher( - model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000 - ) - batcher.load_states(group_states_list) - - finished_states = [] - for batch, _ in batcher: - total_batches += 1 - for _ in range(5): - batch = ts.lbfgs_step(state=batch, model=lj_model) - finished_states.extend(batch.split()) - - restored = batcher.restore_original_order(finished_states) - for idx, finished_state in zip(original_indices, restored, strict=True): - all_finished_with_indices.append((idx, finished_state)) + batcher = BinningAutoBatcher( + model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000 + ) + batcher.load_states(lbfgs_states) - # Sort by original index to restore order - all_finished_with_indices.sort(key=lambda x: x[0]) - all_finished_states = [s for _, s in all_finished_with_indices] + all_finished_states: list[ts.SimState] = [] + total_batches = 0 + for batch, _ in batcher: + total_batches += 1 # noqa: SIM113 + for _ in range(5): + batch = ts.lbfgs_step(state=batch, model=lj_model) + all_finished_states.extend(batch.split()) assert len(all_finished_states) == len(lbfgs_states) - for restored, original in zip(all_finished_states, lbfgs_states, strict=True): - assert torch.all(restored.atomic_numbers == original.atomic_numbers) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index e8f0e828d..33b3ac862 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -341,6 +341,76 @@ def test_bfgs_cell_optimization( ) +def test_unit_cell_bfgs_multi_batch( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test BFGS optimization with multiple batches.""" + generator = torch.Generator(device=ar_supercell_sim_state.device) + + ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) + ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) + + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): + generator.manual_seed(43) + state.positions += ( + torch.randn( + state.positions.shape, + device=state.device, + generator=generator, + ) + * 0.1 + ) + + multi_state = ts.concatenate_states( + [ar_supercell_sim_state_1, ar_supercell_sim_state_2], + device=ar_supercell_sim_state.device, + ) + + # Initialize BFGS optimizer with unit cell filter + state = ts.bfgs_init( + state=multi_state, model=lj_model, cell_filter=ts.CellFilter.unit + ) + initial_state = copy.deepcopy(state) + + # Run optimization + prev_energy = torch.ones(2, device=state.device, dtype=state.energy.dtype) * 1000 + current_energy = initial_state.energy + step = 0 + while not torch.allclose(current_energy, prev_energy, atol=1e-9): + prev_energy = current_energy + state = ts.bfgs_step(state=state, model=lj_model) + current_energy = state.energy + + step += 1 + if step > 500: + raise ValueError("BFGS optimization did not converge") + + # Check that we actually optimized + assert step > 5 + + # Check that energy decreased for both batches + assert torch.all(state.energy < initial_state.energy), ( + "BFGS optimization should reduce energy for all batches" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert torch.all(max_force < 0.2), ( + f"Forces should be small after optimization, got {max_force=}" + ) + + n_ar_atoms = ar_supercell_sim_state.n_atoms + assert not torch.allclose( + state.positions[:n_ar_atoms], multi_state.positions[:n_ar_atoms] + ) + assert not torch.allclose( + state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:] + ) + + # We are evolving identical systems + assert current_energy[0] == current_energy[1] + + @pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) def test_lbfgs_cell_optimization( ar_supercell_sim_state: SimState, @@ -417,6 +487,76 @@ def test_lbfgs_cell_optimization( ) +def test_unit_cell_lbfgs_multi_batch( + ar_supercell_sim_state: SimState, lj_model: ModelInterface +) -> None: + """Test L-BFGS optimization with multiple batches.""" + generator = torch.Generator(device=ar_supercell_sim_state.device) + + ar_supercell_sim_state_1 = copy.deepcopy(ar_supercell_sim_state) + ar_supercell_sim_state_2 = copy.deepcopy(ar_supercell_sim_state) + + for state in (ar_supercell_sim_state_1, ar_supercell_sim_state_2): + generator.manual_seed(43) + state.positions += ( + torch.randn( + state.positions.shape, + device=state.device, + generator=generator, + ) + * 0.1 + ) + + multi_state = ts.concatenate_states( + [ar_supercell_sim_state_1, ar_supercell_sim_state_2], + device=ar_supercell_sim_state.device, + ) + + # Initialize L-BFGS optimizer with unit cell filter + state = ts.lbfgs_init( + state=multi_state, model=lj_model, cell_filter=ts.CellFilter.unit + ) + initial_state = copy.deepcopy(state) + + # Run optimization + prev_energy = torch.ones(2, device=state.device, dtype=state.energy.dtype) * 1000 + current_energy = initial_state.energy + step = 0 + while not torch.allclose(current_energy, prev_energy, atol=1e-9): + prev_energy = current_energy + state = ts.lbfgs_step(state=state, model=lj_model) + current_energy = state.energy + + step += 1 + if step > 500: + raise ValueError("L-BFGS optimization did not converge") + + # Check that we actually optimized + assert step > 5 + + # Check that energy decreased for both batches + assert torch.all(state.energy < initial_state.energy), ( + "L-BFGS optimization should reduce energy for all batches" + ) + + # Check force convergence + max_force = torch.max(torch.norm(state.forces, dim=1)) + assert torch.all(max_force < 0.2), ( + f"Forces should be small after optimization, got {max_force=}" + ) + + n_ar_atoms = ar_supercell_sim_state.n_atoms + assert not torch.allclose( + state.positions[:n_ar_atoms], multi_state.positions[:n_ar_atoms] + ) + assert not torch.allclose( + state.positions[n_ar_atoms:], multi_state.positions[n_ar_atoms:] + ) + + # We are evolving identical systems + assert current_energy[0] == current_energy[1] + + @pytest.mark.parametrize( ("optimizer_fn", "expected_state_type"), [ From 790236f0ec3206212feede466dbac308794cc131 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 06:04:14 -0800 Subject: [PATCH 16/17] why exact check for floating point? --- tests/test_optimizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 33b3ac862..f5071d1f9 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -408,7 +408,7 @@ def test_unit_cell_bfgs_multi_batch( ) # We are evolving identical systems - assert current_energy[0] == current_energy[1] + assert torch.allclose(current_energy[0], current_energy[1]) @pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) @@ -554,7 +554,7 @@ def test_unit_cell_lbfgs_multi_batch( ) # We are evolving identical systems - assert current_energy[0] == current_energy[1] + assert torch.allclose(current_energy[0], current_energy[1]) @pytest.mark.parametrize( From 7c20da1aa56fd5ce11cae5453234e497afe0a1fb Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Wed, 4 Feb 2026 06:34:44 -0800 Subject: [PATCH 17/17] keep example script same --- examples/scripts/2_structural_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/2_structural_optimization.py b/examples/scripts/2_structural_optimization.py index ccf5b9019..288494ab1 100644 --- a/examples/scripts/2_structural_optimization.py +++ b/examples/scripts/2_structural_optimization.py @@ -111,7 +111,7 @@ # Run optimization for step in range(N_steps): - if step % 20 == 0: + if step % 100 == 0: print(f"Step {step}: Potential energy: {state.energy[0].item()} eV") state = ts.fire_step(state=state, model=lj_model, dt_max=0.01)