diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 898bead0d..4bce41658 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -448,10 +448,20 @@ def test_in_flight_auto_batcher_restore_order( # batcher.restore_original_order([si_sim_state]) +@pytest.mark.parametrize( + "num_steps_per_batch", + [ + 5, # At 5 steps, not every state will converge before the next batch. + # This tests the merging of partially converged states with new states + # which has been a bug in the past. See https://github.com/Radical-AI/torch-sim/pull/219 + 10, # At 10 steps, all states will converge before the next batch + ], +) def test_in_flight_with_fire( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, + num_steps_per_batch: int, ) -> None: fire_init, fire_update = unit_cell_fire(lj_model) @@ -489,8 +499,7 @@ def convergence_fn(state: ts.SimState) -> bool: if state is None: break - # run 10 steps, arbitrary number - for _ in range(5): + for _ in range(num_steps_per_batch): state = fire_update(state) convergence_tensor = convergence_fn(state) diff --git a/tests/test_state.py b/tests/test_state.py index af0bda7b3..81109bf36 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -635,3 +635,16 @@ def test_deprecated_batch_properties_equal_to_new_system_properties( state.batch = new_system_idx assert torch.allclose(state.system_idx, new_system_idx) assert torch.allclose(state.batch, new_system_idx) + + +def test_derived_classes_trigger_init_subclass() -> None: + """Test that derived classes cannot have attributes that are "tensors | None".""" + + with pytest.raises(TypeError) as excinfo: + + class DerivedState(SimState): + invalid_attr: torch.Tensor | None = None + + assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str( + excinfo.value + ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 8626ff03f..9a41cd8c5 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -475,7 +475,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor | None + velocities: torch.Tensor # FIRE algorithm parameters dt: torch.Tensor @@ -972,7 +972,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor | None + cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 187cdd89b..8f2917da3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -538,8 +538,8 @@ def static( @dataclass class StaticState(type(state)): energy: torch.Tensor - forces: torch.Tensor | None - stress: torch.Tensor | None + forces: torch.Tensor + stress: torch.Tensor all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames diff --git a/torch_sim/state.py b/torch_sim/state.py index af4db6d19..d2ec83518 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,6 +6,7 @@ import copy import importlib +import typing import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, Self, cast @@ -400,6 +401,33 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def __init_subclass__(cls, **kwargs) -> None: + """Enforce that all derived states cannot have tensor attributes that can also be + None. This is because torch.concatenate cannot concat between a tensor and a None. + See https://github.com/Radical-AI/torch-sim/pull/219 for more details. + """ + # We need to use get_type_hints to correctly inspect the types + type_hints = typing.get_type_hints(cls) + for attr_name, attr_typehint in type_hints.items(): + origin = typing.get_origin(attr_typehint) + + is_union = origin is typing.Union + if not is_union and origin is not None: + # For Python 3.10+ `|` syntax, origin is types.UnionType + # We check by name to be robust against module reloading/patching issues + is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" + if is_union: + args = typing.get_args(attr_typehint) + if torch.Tensor in args and type(None) in args: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be of type 'torch.Tensor | None' because torch.cat " + "cannot concatenate between a tensor and a None. Please default " + "the tensor with dummy values and track the 'None' case." + ) + + super().__init_subclass__(**kwargs) + class DeformGradMixin: """Mixin for states that support deformation gradients."""