From 045b5c079cc2d9db22cb2941e0fc7493088dbb4d Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 16:27:30 -0700 Subject: [PATCH 1/7] added init subclass --- torch_sim/state.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/torch_sim/state.py b/torch_sim/state.py index af4db6d19..d2ec83518 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,6 +6,7 @@ import copy import importlib +import typing import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, Self, cast @@ -400,6 +401,33 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def __init_subclass__(cls, **kwargs) -> None: + """Enforce that all derived states cannot have tensor attributes that can also be + None. This is because torch.concatenate cannot concat between a tensor and a None. + See https://github.com/Radical-AI/torch-sim/pull/219 for more details. + """ + # We need to use get_type_hints to correctly inspect the types + type_hints = typing.get_type_hints(cls) + for attr_name, attr_typehint in type_hints.items(): + origin = typing.get_origin(attr_typehint) + + is_union = origin is typing.Union + if not is_union and origin is not None: + # For Python 3.10+ `|` syntax, origin is types.UnionType + # We check by name to be robust against module reloading/patching issues + is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" + if is_union: + args = typing.get_args(attr_typehint) + if torch.Tensor in args and type(None) in args: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be of type 'torch.Tensor | None' because torch.cat " + "cannot concatenate between a tensor and a None. Please default " + "the tensor with dummy values and track the 'None' case." + ) + + super().__init_subclass__(**kwargs) + class DeformGradMixin: """Mixin for states that support deformation gradients.""" From b5abe8e9bc8891a95a5e654877507f5588bdd559 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 16:39:29 -0700 Subject: [PATCH 2/7] add test_in_flight_with_fire_only_converge_some_states test --- tests/test_autobatching.py | 56 +++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 898bead0d..8cb110678 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -489,7 +489,61 @@ def convergence_fn(state: ts.SimState) -> bool: if state is None: break - # run 10 steps, arbitrary number + # run 10 steps (so all states converge before the next batch) + for _ in range(10): + state = fire_update(state) + convergence_tensor = convergence_fn(state) + + assert len(all_completed_states) == len(fire_states) + + +def test_in_flight_with_fire_only_converge_some_states( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """This test is the same as the test_in_flight_with_fire above + but we only converge a few states before we trigger the auto batcher. + This can catch bugs related to merging partially converged and fully converged + states. See https://github.com/Radical-AI/torch-sim/pull/219 + """ + fire_init, fire_update = unit_cell_fire(lj_model) + + si_fire_state = fire_init(si_sim_state) + fe_fire_state = fire_init(fe_supercell_sim_state) + + fire_states = [si_fire_state, fe_fire_state] * 5 + fire_states = [state.clone() for state in fire_states] + for state in fire_states: + state.positions += torch.randn_like(state.positions) * 0.01 + + batcher = InFlightAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + # max_metric=400_000, + max_memory_scaler=600, + ) + batcher.load_states(fire_states) + + def convergence_fn(state: ts.SimState) -> bool: + system_wise_max_force = torch.zeros( + state.n_systems, device=state.device, dtype=torch.float64 + ) + max_forces = state.forces.norm(dim=1) + system_wise_max_force = system_wise_max_force.scatter_reduce( + dim=0, index=state.system_idx, src=max_forces, reduce="amax" + ) + return system_wise_max_force < 5e-1 + + all_completed_states, convergence_tensor = [], None + while True: + state, completed_states = batcher.next_batch(state, convergence_tensor) + + all_completed_states.extend(completed_states) + if state is None: + break + + # run 5 steps (so not every state can converge before the next batch) for _ in range(5): state = fire_update(state) convergence_tensor = convergence_fn(state) From 97128c9593a286545f93675ab0dab508c04f4cc0 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 16:50:30 -0700 Subject: [PATCH 3/7] tests pass --- tests/test_state.py | 13 +++++++++++++ torch_sim/optimizers.py | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index af0bda7b3..81109bf36 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -635,3 +635,16 @@ def test_deprecated_batch_properties_equal_to_new_system_properties( state.batch = new_system_idx assert torch.allclose(state.system_idx, new_system_idx) assert torch.allclose(state.batch, new_system_idx) + + +def test_derived_classes_trigger_init_subclass() -> None: + """Test that derived classes cannot have attributes that are "tensors | None".""" + + with pytest.raises(TypeError) as excinfo: + + class DerivedState(SimState): + invalid_attr: torch.Tensor | None = None + + assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str( + excinfo.value + ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 8626ff03f..9a41cd8c5 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -475,7 +475,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor | None + velocities: torch.Tensor # FIRE algorithm parameters dt: torch.Tensor @@ -972,7 +972,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor | None + cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor From 4e056070f6c5956ba9bf5eb5cf8ea722e6906367 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 17:00:36 -0700 Subject: [PATCH 4/7] maybe this is correct not sure yet. --- torch_sim/runners.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 187cdd89b..8f2917da3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -538,8 +538,8 @@ def static( @dataclass class StaticState(type(state)): energy: torch.Tensor - forces: torch.Tensor | None - stress: torch.Tensor | None + forces: torch.Tensor + stress: torch.Tensor all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames From c71be9a1f4bb3327dfa1575b8df3f20841653cc9 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 19:30:33 -0700 Subject: [PATCH 5/7] update model interface to have __call__ method --- torch_sim/models/interface.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 27c032779..3955f8175 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -169,6 +169,10 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens ``` """ + @abstractmethod + def __call__(*args, **kwargs) -> dict[str, torch.Tensor]: + """Where the input is fed into the model. This is to help typecheckers.""" + def validate_model_outputs( # noqa: C901, PLR0915 model: ModelInterface, device: torch.device, dtype: torch.dtype From 56abbfd6e09ef6aa12d8a8307fa479ff51a2e8c9 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 3 Aug 2025 07:58:01 -0700 Subject: [PATCH 6/7] parameterize num_steps_per_batch for test_in_flight_with_fire --- tests/test_autobatching.py | 67 +++++++------------------------------- 1 file changed, 11 insertions(+), 56 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 8cb110678..4bce41658 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -448,10 +448,20 @@ def test_in_flight_auto_batcher_restore_order( # batcher.restore_original_order([si_sim_state]) +@pytest.mark.parametrize( + "num_steps_per_batch", + [ + 5, # At 5 steps, not every state will converge before the next batch. + # This tests the merging of partially converged states with new states + # which has been a bug in the past. See https://github.com/Radical-AI/torch-sim/pull/219 + 10, # At 10 steps, all states will converge before the next batch + ], +) def test_in_flight_with_fire( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, + num_steps_per_batch: int, ) -> None: fire_init, fire_update = unit_cell_fire(lj_model) @@ -489,62 +499,7 @@ def convergence_fn(state: ts.SimState) -> bool: if state is None: break - # run 10 steps (so all states converge before the next batch) - for _ in range(10): - state = fire_update(state) - convergence_tensor = convergence_fn(state) - - assert len(all_completed_states) == len(fire_states) - - -def test_in_flight_with_fire_only_converge_some_states( - si_sim_state: ts.SimState, - fe_supercell_sim_state: ts.SimState, - lj_model: LennardJonesModel, -) -> None: - """This test is the same as the test_in_flight_with_fire above - but we only converge a few states before we trigger the auto batcher. - This can catch bugs related to merging partially converged and fully converged - states. See https://github.com/Radical-AI/torch-sim/pull/219 - """ - fire_init, fire_update = unit_cell_fire(lj_model) - - si_fire_state = fire_init(si_sim_state) - fe_fire_state = fire_init(fe_supercell_sim_state) - - fire_states = [si_fire_state, fe_fire_state] * 5 - fire_states = [state.clone() for state in fire_states] - for state in fire_states: - state.positions += torch.randn_like(state.positions) * 0.01 - - batcher = InFlightAutoBatcher( - model=lj_model, - memory_scales_with="n_atoms", - # max_metric=400_000, - max_memory_scaler=600, - ) - batcher.load_states(fire_states) - - def convergence_fn(state: ts.SimState) -> bool: - system_wise_max_force = torch.zeros( - state.n_systems, device=state.device, dtype=torch.float64 - ) - max_forces = state.forces.norm(dim=1) - system_wise_max_force = system_wise_max_force.scatter_reduce( - dim=0, index=state.system_idx, src=max_forces, reduce="amax" - ) - return system_wise_max_force < 5e-1 - - all_completed_states, convergence_tensor = [], None - while True: - state, completed_states = batcher.next_batch(state, convergence_tensor) - - all_completed_states.extend(completed_states) - if state is None: - break - - # run 5 steps (so not every state can converge before the next batch) - for _ in range(5): + for _ in range(num_steps_per_batch): state = fire_update(state) convergence_tensor = convergence_fn(state) From 3d23f7da01c8ed2c20ae81fbc3d933962d16a951 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 7 Aug 2025 20:23:08 -0700 Subject: [PATCH 7/7] rm call --- torch_sim/models/interface.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 3955f8175..27c032779 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -169,10 +169,6 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens ``` """ - @abstractmethod - def __call__(*args, **kwargs) -> dict[str, torch.Tensor]: - """Where the input is fed into the model. This is to help typecheckers.""" - def validate_model_outputs( # noqa: C901, PLR0915 model: ModelInterface, device: torch.device, dtype: torch.dtype