Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,74 @@ def test_fire_vv_negative_power_branch(
)


@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor))
@pytest.mark.parametrize("cell_filter", [None, ts.CellFilter.unit, ts.CellFilter.frechet])
def test_fire_nan_velocities_dont_affect_other_systems(
ar_supercell_sim_state: SimState,
lj_model: ModelInterface,
fire_flavor: "FireFlavor",
cell_filter: "ts.CellFilter | None",
) -> None:
"""Injecting NaN velocities into one system must not alter another's trajectory.

Regression: _ase_fire_step used ``if nan_velocities.any()`` to skip force
transformation AND FIRE mixing for ALL systems. When the InFlightAutoBatcher
swaps in a new state (NaN velocities from fire_init), retained systems got
skipped FIRE mixing and untransformed forces. This test clones a state,
injects NaN into one copy's system 1, and verifies system 0 is identical.
"""
multi = ts.concatenate_states(
[ar_supercell_sim_state, copy.deepcopy(ar_supercell_sim_state)]
)

init_kwargs: dict[str, Any] = {"fire_flavor": fire_flavor}
if cell_filter is not None:
multi.cell = multi.cell * 0.85
multi.positions = multi.positions * 0.85
init_kwargs["cell_filter"] = cell_filter

state = ts.fire_init(state=multi, model=lj_model, **init_kwargs)

# Evolve 10 steps so system 0 has non-trivial FIRE state (dt, alpha, n_pos)
for _ in range(10):
state = ts.fire_step(state=state, model=lj_model)

# Clone, then inject NaN into system 1 of one copy
state_clean = copy.deepcopy(state)
state_mixed = copy.deepcopy(state)

sys1_atoms = state_mixed.system_idx == 1
state_mixed.velocities[sys1_atoms] = float("nan")
if cell_filter is not None:
state_mixed.cell_velocities[1] = float("nan")

# One step each
state_clean = ts.fire_step(state=state_clean, model=lj_model)
state_mixed = ts.fire_step(state=state_mixed, model=lj_model)

# System 0 must be identical regardless of system 1's NaN velocities
sys0 = state_clean.system_idx == 0
assert torch.equal(state_mixed.positions[sys0], state_clean.positions[sys0]), (
"System 0 positions differ when system 1 has NaN velocities"
)
assert torch.equal(state_mixed.velocities[sys0], state_clean.velocities[sys0]), (
"System 0 velocities differ when system 1 has NaN velocities"
)
assert state_mixed.dt[0] == state_clean.dt[0], (
"System 0 dt differs when system 1 has NaN velocities"
)
assert state_mixed.alpha[0] == state_clean.alpha[0], (
"System 0 alpha differs when system 1 has NaN velocities"
)
assert state_mixed.n_pos[0] == state_clean.n_pos[0], (
"System 0 n_pos differs when system 1 has NaN velocities"
)
if cell_filter is not None:
assert torch.equal(state_mixed.cell[0], state_clean.cell[0]), (
"System 0 cell differs when system 1 has NaN velocities"
)


@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor))
def test_unit_cell_fire_optimization(
ar_supercell_sim_state: SimState, lj_model: ModelInterface, fire_flavor: FireFlavor
Expand Down
33 changes: 13 additions & 20 deletions torch_sim/optimizers/fire.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,10 @@ def _vv_fire_step[T: "FireState | CellFireState"](
# Initialize velocities if NaN
nan_velocities = state.velocities.isnan().any(dim=1)
if nan_velocities.any():
state.velocities[nan_velocities] = torch.zeros_like(
state.positions[nan_velocities]
)
if isinstance(state, CellFireState): # update velocities to zero if NaN
nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
state.cell_positions[nan_cell_velocities]
)
state.velocities[nan_velocities] = 0
if isinstance(state, CellFireState):
nan_cell_vel = state.cell_velocities.isnan().any(dim=(1, 2))
state.cell_velocities[nan_cell_vel] = 0

alpha_start_system = torch.full(
(n_systems,), alpha_start.item(), device=device, dtype=dtype
Expand Down Expand Up @@ -302,26 +298,23 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915

n_systems, device, dtype = state.n_systems, state.device, state.dtype

# Initialize velocities if NaN
# Per-atom NaN detection before zeroing: needed to decide whether to skip
# FIRE mixing (all NaN = first step) vs run it (partial NaN = autobatcher swap).
nan_velocities = state.velocities.isnan().any(dim=1)
if nan_velocities.any():
state.velocities[nan_velocities] = torch.zeros_like(
state.velocities[nan_velocities]
)
state.velocities.nan_to_num_(nan=0.0)
if isinstance(state, CellFireState):
state.cell_velocities.nan_to_num_(nan=0.0)

if nan_velocities.all():
# First step: all NaN → zero. Use raw forces, skip FIRE mixing (matches ASE).
forces = state.forces
if isinstance(state, CellFireState):
nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
state.cell_velocities[nan_cell_velocities]
)
else:
alpha_start_system = torch.full(
(n_systems,), alpha_start.item(), device=device, dtype=dtype
)

# Transform forces for cell optimization
if isinstance(state, CellFireState):
# Get deformation gradient for force transformation
cur_deform_grad = cell_filters.deform_grad(
state.row_vector_cell,
getattr(state, "reference_row_vector_cell", state.row_vector_cell),
Expand All @@ -332,7 +325,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915
else:
forces = state.forces

# Calculate power
# Calculate power (newly zeroed systems will have power=0 → neg_mask)
system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx)
if isinstance(state, CellFireState):
system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2))
Expand Down
Loading