diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index ad79d4a0f..cae16fb32 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -715,6 +715,74 @@ def test_fire_vv_negative_power_branch( ) +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +@pytest.mark.parametrize("cell_filter", [None, ts.CellFilter.unit, ts.CellFilter.frechet]) +def test_fire_nan_velocities_dont_affect_other_systems( + ar_supercell_sim_state: SimState, + lj_model: ModelInterface, + fire_flavor: "FireFlavor", + cell_filter: "ts.CellFilter | None", +) -> None: + """Injecting NaN velocities into one system must not alter another's trajectory. + + Regression: _ase_fire_step used ``if nan_velocities.any()`` to skip force + transformation AND FIRE mixing for ALL systems. When the InFlightAutoBatcher + swaps in a new state (NaN velocities from fire_init), retained systems + skipped FIRE mixing and received untransformed forces. This test clones a state, + injects NaN into one copy's system 1, and verifies system 0 is unaffected.
+ """ + multi = ts.concatenate_states( + [ar_supercell_sim_state, copy.deepcopy(ar_supercell_sim_state)] + ) + + init_kwargs: dict[str, Any] = {"fire_flavor": fire_flavor} + if cell_filter is not None: + multi.cell = multi.cell * 0.85 + multi.positions = multi.positions * 0.85 + init_kwargs["cell_filter"] = cell_filter + + state = ts.fire_init(state=multi, model=lj_model, **init_kwargs) + + # Evolve 10 steps so system 0 has non-trivial FIRE state (dt, alpha, n_pos) + for _ in range(10): + state = ts.fire_step(state=state, model=lj_model) + + # Clone, then inject NaN into system 1 of one copy + state_clean = copy.deepcopy(state) + state_mixed = copy.deepcopy(state) + + sys1_atoms = state_mixed.system_idx == 1 + state_mixed.velocities[sys1_atoms] = float("nan") + if cell_filter is not None: + state_mixed.cell_velocities[1] = float("nan") + + # One step each + state_clean = ts.fire_step(state=state_clean, model=lj_model) + state_mixed = ts.fire_step(state=state_mixed, model=lj_model) + + # System 0 must be identical regardless of system 1's NaN velocities + sys0 = state_clean.system_idx == 0 + assert torch.equal(state_mixed.positions[sys0], state_clean.positions[sys0]), ( + "System 0 positions differ when system 1 has NaN velocities" + ) + assert torch.equal(state_mixed.velocities[sys0], state_clean.velocities[sys0]), ( + "System 0 velocities differ when system 1 has NaN velocities" + ) + assert state_mixed.dt[0] == state_clean.dt[0], ( + "System 0 dt differs when system 1 has NaN velocities" + ) + assert state_mixed.alpha[0] == state_clean.alpha[0], ( + "System 0 alpha differs when system 1 has NaN velocities" + ) + assert state_mixed.n_pos[0] == state_clean.n_pos[0], ( + "System 0 n_pos differs when system 1 has NaN velocities" + ) + if cell_filter is not None: + assert torch.equal(state_mixed.cell[0], state_clean.cell[0]), ( + "System 0 cell differs when system 1 has NaN velocities" + ) + + @pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) def 
test_unit_cell_fire_optimization( ar_supercell_sim_state: SimState, lj_model: ModelInterface, fire_flavor: FireFlavor diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 80fee98fc..e35ec1c18 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -190,14 +190,10 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # Initialize velocities if NaN nan_velocities = state.velocities.isnan().any(dim=1) if nan_velocities.any(): - state.velocities[nan_velocities] = torch.zeros_like( - state.positions[nan_velocities] - ) - if isinstance(state, CellFireState): # update velocities to zero if NaN - nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) - state.cell_velocities[nan_cell_velocities] = torch.zeros_like( - state.cell_positions[nan_cell_velocities] - ) + state.velocities[nan_velocities] = 0 + if isinstance(state, CellFireState): + nan_cell_vel = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_vel] = 0 alpha_start_system = torch.full( (n_systems,), alpha_start.item(), device=device, dtype=dtype @@ -302,18 +298,16 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 n_systems, device, dtype = state.n_systems, state.device, state.dtype - # Initialize velocities if NaN + # Per-atom NaN detection before zeroing: needed to decide whether to skip + # FIRE mixing (all NaN = first step) vs run it (partial NaN = autobatcher swap). nan_velocities = state.velocities.isnan().any(dim=1) - if nan_velocities.any(): - state.velocities[nan_velocities] = torch.zeros_like( - state.velocities[nan_velocities] - ) + state.velocities.nan_to_num_(nan=0.0) + if isinstance(state, CellFireState): + state.cell_velocities.nan_to_num_(nan=0.0) + + if nan_velocities.all(): + # First step: all NaN → zero. Use raw forces, skip FIRE mixing (matches ASE). 
forces = state.forces - if isinstance(state, CellFireState): - nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) - state.cell_velocities[nan_cell_velocities] = torch.zeros_like( - state.cell_velocities[nan_cell_velocities] - ) else: alpha_start_system = torch.full( (n_systems,), alpha_start.item(), device=device, dtype=dtype @@ -321,7 +315,6 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 # Transform forces for cell optimization if isinstance(state, CellFireState): - # Get deformation gradient for force transformation cur_deform_grad = cell_filters.deform_grad( state.row_vector_cell, getattr(state, "reference_row_vector_cell", state.row_vector_cell), @@ -332,7 +325,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 else: forces = state.forces - # Calculate power + # Calculate power (newly zeroed systems will have power=0 → neg_mask) system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx) if isinstance(state, CellFireState): system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2))