From caa7482897d94f884eedef83c0d7c419a04eabd8 Mon Sep 17 00:00:00 2001 From: janosh Date: Wed, 4 Mar 2026 16:50:47 +0000 Subject: [PATCH 1/2] Fix ASE FIRE NaN velocity handling corrupting retained systems in batched optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The NaN-velocity sentinel in `_ase_fire_step` used an `if/else` that skipped force transformation AND FIRE mixing for ALL systems whenever ANY system had NaN velocities. When the InFlightAutoBatcher swaps in a new state (NaN velocities from fire_init), retained systems had their FIRE state update (dt, alpha, n_pos, velocity mixing) skipped entirely. Proven with a clone-and-compare test: on main, system 0's positions, dt, alpha, and n_pos all differ when system 1 has NaN velocities injected. With the fix, all diffs are exactly 0. Fix: decouple NaN zeroing from the FIRE logic branch. Only skip FIRE mixing when ALL velocities are NaN (first step, matching ASE behavior). When a subset has NaN (autobatcher swap), zero them and proceed with normal FIRE logic — newly zeroed systems naturally get power=0 → negative mask → dt decrease and velocity reset. Note: this is separate from Killian's FixSymmetry NaN error, which also reproduces with LBFGS and appears to be an autobatcher-level issue.
--- tests/test_optimizers.py | 72 ++++++++++++++++++++++++++++++++++++ torch_sim/optimizers/fire.py | 23 +++++++++--- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index ad79d4a0f..652b1e5a3 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -715,6 +715,78 @@ def test_fire_vv_negative_power_branch( ) +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +@pytest.mark.parametrize("cell_filter", [None, ts.CellFilter.frechet]) +def test_fire_nan_velocities_dont_affect_other_systems( + ar_supercell_sim_state: SimState, + lj_model: ModelInterface, + fire_flavor: "FireFlavor", + cell_filter: "ts.CellFilter | None", +) -> None: + """Injecting NaN velocities into one system must not alter another's trajectory. + + Regression: _ase_fire_step used ``if nan_velocities.any()`` to skip force + transformation AND FIRE mixing for ALL systems. When the InFlightAutoBatcher + swaps in a new state (NaN velocities from fire_init), retained systems got + skipped FIRE mixing and untransformed forces. This test clones a state, + injects NaN into one copy's system 1, and verifies system 0 is identical. 
+ """ + state_a = copy.deepcopy(ar_supercell_sim_state) + state_b = copy.deepcopy(ar_supercell_sim_state) + multi = ts.concatenate_states([state_a, state_b]) + + init_kwargs: dict = {"fire_flavor": fire_flavor} + if cell_filter is not None: + multi.cell = multi.cell * 0.85 + multi.positions = multi.positions * 0.85 + init_kwargs["cell_filter"] = cell_filter + + state = ts.fire_init(state=multi, model=lj_model, **init_kwargs) + + # Evolve 10 steps so system 0 has non-trivial FIRE state (dt, alpha, n_pos) + for _ in range(10): + state = ts.fire_step(state=state, model=lj_model) + + # Clone, then inject NaN into system 1 of one copy + state_clean = copy.deepcopy(state) + state_mixed = copy.deepcopy(state) + + sys1_atoms = state_mixed.system_idx == 1 + state_mixed.velocities[sys1_atoms] = torch.full_like( + state_mixed.velocities[sys1_atoms], torch.nan + ) + if cell_filter is not None: + state_mixed.cell_velocities[1] = torch.full_like( + state_mixed.cell_velocities[1], torch.nan + ) + + # One step each + state_clean = ts.fire_step(state=state_clean, model=lj_model) + state_mixed = ts.fire_step(state=state_mixed, model=lj_model) + + # System 0 must be identical regardless of system 1's NaN velocities + sys0 = state_clean.system_idx == 0 + assert torch.equal(state_mixed.positions[sys0], state_clean.positions[sys0]), ( + "System 0 positions differ when system 1 has NaN velocities" + ) + assert torch.equal(state_mixed.velocities[sys0], state_clean.velocities[sys0]), ( + "System 0 velocities differ when system 1 has NaN velocities" + ) + assert state_mixed.dt[0] == state_clean.dt[0], ( + "System 0 dt differs when system 1 has NaN velocities" + ) + assert state_mixed.alpha[0] == state_clean.alpha[0], ( + "System 0 alpha differs when system 1 has NaN velocities" + ) + assert state_mixed.n_pos[0] == state_clean.n_pos[0], ( + "System 0 n_pos differs when system 1 has NaN velocities" + ) + if cell_filter is not None: + assert torch.equal(state_mixed.cell[0], 
state_clean.cell[0]), ( + "System 0 cell differs when system 1 has NaN velocities" + ) + + @pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) def test_unit_cell_fire_optimization( ar_supercell_sim_state: SimState, lj_model: ModelInterface, fire_flavor: FireFlavor diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 80fee98fc..00f15731d 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -302,18 +302,32 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 n_systems, device, dtype = state.n_systems, state.device, state.dtype - # Initialize velocities if NaN + # Zero out NaN velocities (sentinel from fire_init for first-step detection). + # Track per-system so we can skip FIRE mixing only for newly initialized + # systems while still transforming forces for retained systems. Without this, + # autobatcher state swaps cause all systems to get untransformed forces and + # no velocity mixing, destabilizing cell optimization with FixSymmetry. nan_velocities = state.velocities.isnan().any(dim=1) - if nan_velocities.any(): + has_nan = nan_velocities.any() + if has_nan: state.velocities[nan_velocities] = torch.zeros_like( state.velocities[nan_velocities] ) - forces = state.forces if isinstance(state, CellFireState): nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) state.cell_velocities[nan_cell_velocities] = torch.zeros_like( state.cell_velocities[nan_cell_velocities] ) + + # Skip FIRE mixing entirely when ALL systems are new (first step, matches ASE). + # When only SOME atoms have NaN (autobatcher added new systems), we must still + # run force transformation and FIRE mixing for retained systems. + all_nan = has_nan and nan_velocities.all() + + if all_nan: + # First step: all velocities were NaN → zero. Use raw forces and skip + # FIRE mixing (ASE-compatible: ASE also skips mixing when v is None). 
+ forces = state.forces else: alpha_start_system = torch.full( (n_systems,), alpha_start.item(), device=device, dtype=dtype @@ -321,7 +335,6 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 # Transform forces for cell optimization if isinstance(state, CellFireState): - # Get deformation gradient for force transformation cur_deform_grad = cell_filters.deform_grad( state.row_vector_cell, getattr(state, "reference_row_vector_cell", state.row_vector_cell), @@ -332,7 +345,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 else: forces = state.forces - # Calculate power + # Calculate power (newly zeroed systems will have power=0 → neg_mask) system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx) if isinstance(state, CellFireState): system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) From a806bbfe244d87d0f6deae1051d77e0992352784 Mon Sep 17 00:00:00 2001 From: janosh Date: Wed, 4 Mar 2026 17:14:22 +0000 Subject: [PATCH 2/2] Simplify NaN velocity handling and improve test coverage Use nan_to_num_() instead of conditional masked zeroing in _ase_fire_step, fix _vv_fire_step using zeros_like(positions) instead of velocities, and add unit cell filter to NaN isolation test parametrization. 
--- tests/test_optimizers.py | 18 ++++++---------- torch_sim/optimizers/fire.py | 42 ++++++++++-------------------------- 2 files changed, 18 insertions(+), 42 deletions(-) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 652b1e5a3..cae16fb32 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -716,7 +716,7 @@ def test_fire_vv_negative_power_branch( @pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) -@pytest.mark.parametrize("cell_filter", [None, ts.CellFilter.frechet]) +@pytest.mark.parametrize("cell_filter", [None, ts.CellFilter.unit, ts.CellFilter.frechet]) def test_fire_nan_velocities_dont_affect_other_systems( ar_supercell_sim_state: SimState, lj_model: ModelInterface, @@ -731,11 +731,11 @@ def test_fire_nan_velocities_dont_affect_other_systems( skipped FIRE mixing and untransformed forces. This test clones a state, injects NaN into one copy's system 1, and verifies system 0 is identical. """ - state_a = copy.deepcopy(ar_supercell_sim_state) - state_b = copy.deepcopy(ar_supercell_sim_state) - multi = ts.concatenate_states([state_a, state_b]) + multi = ts.concatenate_states( + [ar_supercell_sim_state, copy.deepcopy(ar_supercell_sim_state)] + ) - init_kwargs: dict = {"fire_flavor": fire_flavor} + init_kwargs: dict[str, Any] = {"fire_flavor": fire_flavor} if cell_filter is not None: multi.cell = multi.cell * 0.85 multi.positions = multi.positions * 0.85 @@ -752,13 +752,9 @@ def test_fire_nan_velocities_dont_affect_other_systems( state_mixed = copy.deepcopy(state) sys1_atoms = state_mixed.system_idx == 1 - state_mixed.velocities[sys1_atoms] = torch.full_like( - state_mixed.velocities[sys1_atoms], torch.nan - ) + state_mixed.velocities[sys1_atoms] = float("nan") if cell_filter is not None: - state_mixed.cell_velocities[1] = torch.full_like( - state_mixed.cell_velocities[1], torch.nan - ) + state_mixed.cell_velocities[1] = float("nan") # One step each state_clean = ts.fire_step(state=state_clean, model=lj_model) 
diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 00f15731d..e35ec1c18 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -190,14 +190,10 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # Initialize velocities if NaN nan_velocities = state.velocities.isnan().any(dim=1) if nan_velocities.any(): - state.velocities[nan_velocities] = torch.zeros_like( - state.positions[nan_velocities] - ) - if isinstance(state, CellFireState): # update velocities to zero if NaN - nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) - state.cell_velocities[nan_cell_velocities] = torch.zeros_like( - state.cell_positions[nan_cell_velocities] - ) + state.velocities[nan_velocities] = 0 + if isinstance(state, CellFireState): + nan_cell_vel = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_vel] = 0 alpha_start_system = torch.full( (n_systems,), alpha_start.item(), device=device, dtype=dtype @@ -302,31 +298,15 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 n_systems, device, dtype = state.n_systems, state.device, state.dtype - # Zero out NaN velocities (sentinel from fire_init for first-step detection). - # Track per-system so we can skip FIRE mixing only for newly initialized - # systems while still transforming forces for retained systems. Without this, - # autobatcher state swaps cause all systems to get untransformed forces and - # no velocity mixing, destabilizing cell optimization with FixSymmetry. + # Per-atom NaN detection before zeroing: needed to decide whether to skip + # FIRE mixing (all NaN = first step) vs run it (partial NaN = autobatcher swap). 
nan_velocities = state.velocities.isnan().any(dim=1) - has_nan = nan_velocities.any() - if has_nan: - state.velocities[nan_velocities] = torch.zeros_like( - state.velocities[nan_velocities] - ) - if isinstance(state, CellFireState): - nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) - state.cell_velocities[nan_cell_velocities] = torch.zeros_like( - state.cell_velocities[nan_cell_velocities] - ) - - # Skip FIRE mixing entirely when ALL systems are new (first step, matches ASE). - # When only SOME atoms have NaN (autobatcher added new systems), we must still - # run force transformation and FIRE mixing for retained systems. - all_nan = has_nan and nan_velocities.all() + state.velocities.nan_to_num_(nan=0.0) + if isinstance(state, CellFireState): + state.cell_velocities.nan_to_num_(nan=0.0) - if all_nan: - # First step: all velocities were NaN → zero. Use raw forces and skip - # FIRE mixing (ASE-compatible: ASE also skips mixing when v is None). + if nan_velocities.all(): + # First step: all NaN → zero. Use raw forces, skip FIRE mixing (matches ASE). forces = state.forces else: alpha_start_system = torch.full(