diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index e51ead764..3ab10533c 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -1,4 +1,5 @@
 import copy
+import warnings
 from collections.abc import Callable
 from functools import partial
 from typing import Any, get_args
@@ -1086,6 +1087,57 @@ def test_frechet_cell_fire_optimization(
     )
 
 
+def test_frechet_lbfgs_clamps_extreme_deformation(
+    ar_supercell_sim_state: SimState, lj_model: ModelInterface
+) -> None:
+    """LBFGS + Frechet cell filter clamps extreme log-space deformation.
+
+    Injects extreme cell_positions so that cell_positions / cell_factor > 2.0,
+    then verifies: (1) the clamp warning fires, (2) the log-space deformation
+    is bounded after the step, and (3) positions/cell remain finite.
+    """
+    state = ts.lbfgs_init(
+        state=ar_supercell_sim_state,
+        model=lj_model,
+        cell_filter=ts.CellFilter.frechet,
+    )
+
+    # Inject extreme cell_positions: log-deform = cell_positions/cell_factor = 10
+    # This far exceeds the MAX_LOG_DEFORM=2.0 clamp threshold.
+    state.cell_positions = state.cell_positions + 10.0 * state.cell_factor.view(
+        -1, 1, 1
+    ) * torch.eye(3, device=state.cell.device, dtype=state.cell.dtype).unsqueeze(0)
+
+    log_deform_before = (
+        (state.cell_positions / state.cell_factor.view(-1, 1, 1)).abs().max().item()
+    )
+    assert log_deform_before > 5.0, "Setup: log-deform should be extreme before step"
+
+    with warnings.catch_warnings(record=True) as caught_warnings:
+        warnings.simplefilter("always")
+        state = ts.lbfgs_step(state=state, model=lj_model)
+
+    # 1. Log-space deformation must be bounded after the clamp
+    log_deform_after = (
+        (state.cell_positions / state.cell_factor.view(-1, 1, 1)).abs().max().item()
+    )
+    assert log_deform_after <= 2.5, (
+        f"Log-space deformation should be clamped to ~2.0, got {log_deform_after:.2f}"
+    )
+
+    # 2. Positions and cell must remain finite
+    assert not torch.isnan(state.positions).any(), "Positions contain NaN"
+    assert not torch.isinf(state.positions).any(), "Positions contain Inf"
+    assert not torch.isnan(state.cell).any(), "Cell contains NaN"
+    assert not torch.isinf(state.cell).any(), "Cell contains Inf"
+
+    # 3. The clamp warning must have fired
+    clamp_warnings = [
+        warn for warn in caught_warnings if "Clamping log-space" in str(warn.message)
+    ]
+    assert len(clamp_warnings) > 0, "Expected clamping warning but none was emitted"
+
+
 @pytest.mark.parametrize(
     "filter_func",
     [None, ts.CellFilter.unit, ts.CellFilter.frechet],
diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py
index 5cccf96ca..c3a344cf3 100644
--- a/torch_sim/optimizers/bfgs.py
+++ b/torch_sim/optimizers/bfgs.py
@@ -20,7 +20,11 @@
 
 import torch_sim as ts
 from torch_sim.optimizers import cell_filters
-from torch_sim.optimizers.cell_filters import CellBFGSState, frechet_cell_filter_init
+from torch_sim.optimizers.cell_filters import (
+    CellBFGSState,
+    _clamp_deform_grad_log,
+    frechet_cell_filter_init,
+)
 from torch_sim.state import SimState
 
 
@@ -507,6 +511,10 @@ def bfgs_step(  # noqa: C901, PLR0915
            # Frechet: deform_grad = exp(cell_positions / cell_factor)
            cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1)
            deform_grad_log_new = cell_positions_new / cell_factor_reshaped  # [S, 3, 3]
+           deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
+               deform_grad_log_new, cell_positions_new, cell_factor_reshaped
+           )
+           state.cell_positions = cell_positions_new  # [S, 3, 3]
            deform_grad_new = torch.matrix_exp(deform_grad_log_new)  # [S, 3, 3]
        else:
            # UnitCell: deform_grad = cell_positions / cell_factor
diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py
index b501a7aa8..8a58f8d40 100644
--- a/torch_sim/optimizers/cell_filters.py
+++ b/torch_sim/optimizers/cell_filters.py
@@ -6,6 +6,7 @@ during optimization.
 """
 
+import warnings
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from enum import StrEnum
 
@@ -19,6 +20,46 @@
 from torch_sim.state import SimState
 
 
+MAX_LOG_DEFORM = 2.0
+
+
+def _clamp_deform_grad_log(
+    deform_grad_log: torch.Tensor,
+    cell_positions: torch.Tensor,
+    cell_factor_reshaped: torch.Tensor,
+    *,
+    max_log_deform: float = MAX_LOG_DEFORM,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Clamp log-space deformation gradient to prevent matrix_exp overflow.
+
+    When cell_positions grow unbounded (from diverging structures or extreme
+    steps), matrix_exp overflows to Inf/NaN. This clamps the log-space values
+    and writes back the clamped cell_positions so they don't re-accumulate.
+
+    Args:
+        deform_grad_log: Log of the deformation gradient, shape (S, 3, 3).
+        cell_positions: Current cell positions in log space, shape (S, 3, 3).
+        cell_factor_reshaped: Cell factor broadcast to (S, 1, 1).
+        max_log_deform: Maximum absolute value for log-space entries.
+
+    Returns:
+        Tuple of (clamped deform_grad_log, clamped cell_positions).
+    """
+    exceeds = deform_grad_log.abs() > max_log_deform
+    if exceeds.any():
+        n_clamped = int(exceeds.any(dim=(-2, -1)).sum().item())
+        warnings.warn(
+            f"Clamping log-space deformation gradient for {n_clamped} "
+            f"system(s) to [-{max_log_deform}, {max_log_deform}] "
+            f"(max |log(F)| = {deform_grad_log.abs().max().item():.2f}). "
+            f"This prevents matrix_exp overflow from diverging cell optimization.",
+            stacklevel=3,
+        )
+        deform_grad_log = deform_grad_log.clamp(-max_log_deform, max_log_deform)
+        cell_positions = deform_grad_log * cell_factor_reshaped
+    return deform_grad_log, cell_positions
+
+
 def _setup_cell_factor(
     state: SimState,
     cell_factor: float | torch.Tensor | None,
@@ -294,6 +335,9 @@ def frechet_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor)
     # Convert from log space to deformation gradient
     cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1)
     deform_grad_log_new = cell_positions_new / cell_factor_reshaped
+    deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
+        deform_grad_log_new, cell_positions_new, cell_factor_reshaped
+    )
     deform_grad_new = torch.matrix_exp(deform_grad_log_new)
 
     # Update cell from new deformation gradient
@@ -336,6 +380,10 @@ def compute_cell_forces[T: AnyCellState](
         deform_grad_log = tsm.matrix_log_33(
             cur_deform_grad, sim_dtype=cur_deform_grad.dtype
         )
+        # Clamp to the same limit used in lbfgs_step to prevent NaN from
+        # propagating into expm_frechet. Systems hitting the clamp have
+        # diverging cells; their cell forces will be approximate but finite.
+        deform_grad_log = deform_grad_log.clamp(-MAX_LOG_DEFORM, MAX_LOG_DEFORM)
         frechet_method = getattr(state, "frechet_method", None)
         cell_forces = _frechet_cell_forces(
             deform_grad_log, ucf_cell_grad, frechet_method=frechet_method
diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py
index e35ec1c18..8efcb3a7b 100644
--- a/torch_sim/optimizers/fire.py
+++ b/torch_sim/optimizers/fire.py
@@ -8,6 +8,7 @@
 
 import torch_sim.math as tsm
 from torch_sim._duecredit import dcite
 from torch_sim.optimizers import CellFireState, cell_filters
+from torch_sim.optimizers.cell_filters import _clamp_deform_grad_log
 from torch_sim.state import SimState
 
@@ -428,6 +429,10 @@ def _ase_fire_step[T: "FireState | CellFireState"](  # noqa: C901, PLR0915
        if is_frechet:  # Frechet: convert from log space to deformation gradient
            cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1)
            deform_grad_log_new = cell_positions_new / cell_factor_reshaped
+           deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
+               deform_grad_log_new, cell_positions_new, cell_factor_reshaped
+           )
+           state.cell_positions = cell_positions_new
            deform_grad_new = torch.matrix_exp(deform_grad_log_new)
        else:  # Unit cell: positions are scaled deformation gradient
            cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1)
diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py
index 0221f92d7..f8413399c 100644
--- a/torch_sim/optimizers/lbfgs.py
+++ b/torch_sim/optimizers/lbfgs.py
@@ -18,6 +18,7 @@
 import torch_sim as ts
 from torch_sim.optimizers.cell_filters import (
     CellLBFGSState,
+    _clamp_deform_grad_log,
     compute_cell_forces,
     deform_grad,
     frechet_cell_filter_init,
@@ -486,7 +487,6 @@ def lbfgs_step(  # noqa: PLR0915, C901
        # Apply cell step
        dr_cell = step_cell  # [S, 3, 3]
        cell_positions_new = state.cell_positions + dr_cell  # [S, 3, 3]
-       state.cell_positions = cell_positions_new  # [S, 3, 3]
 
        # Determine if Frechet filter
        init_fn, _step_fn = state.cell_filter
@@ -496,12 +496,17 @@ def lbfgs_step(  # noqa: PLR0915, C901
            # Frechet: deform_grad = exp(cell_positions / cell_factor)
            cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1)
            deform_grad_log_new = cell_positions_new / cell_factor_reshaped  # [S, 3, 3]
+           deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
+               deform_grad_log_new, cell_positions_new, cell_factor_reshaped
+           )
            deform_grad_new = torch.matrix_exp(deform_grad_log_new)  # [S, 3, 3]
        else:
            # UnitCell: deform_grad = cell_positions / cell_factor
            cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1)
            deform_grad_new = cell_positions_new / cell_factor_expanded  # [S, 3, 3]
 
+       state.cell_positions = cell_positions_new  # [S, 3, 3]
+
        # Update cell: new_cell = reference_cell @ deform_grad^T
        # Use set_constrained_cell to apply cell constraints (e.g. FixSymmetry)
        new_col_vector_cell = torch.bmm(