Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tests/test_autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,20 @@ def test_in_flight_auto_batcher_restore_order(
# batcher.restore_original_order([si_sim_state])


@pytest.mark.parametrize(
"num_steps_per_batch",
[
5, # At 5 steps, not every state will converge before the next batch.
# This tests the merging of partially converged states with new states
# which has been a bug in the past. See https://github.com/Radical-AI/torch-sim/pull/219
10, # At 10 steps, all states will converge before the next batch
],
)
def test_in_flight_with_fire(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
num_steps_per_batch: int,
) -> None:
fire_init, fire_update = unit_cell_fire(lj_model)

Expand Down Expand Up @@ -489,8 +499,7 @@ def convergence_fn(state: ts.SimState) -> bool:
if state is None:
break

# run 10 steps, arbitrary number
for _ in range(5):
for _ in range(num_steps_per_batch):
state = fire_update(state)
convergence_tensor = convergence_fn(state)

Expand Down
13 changes: 13 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,3 +635,16 @@ def test_deprecated_batch_properties_equal_to_new_system_properties(
state.batch = new_system_idx
assert torch.allclose(state.system_idx, new_system_idx)
assert torch.allclose(state.batch, new_system_idx)


def test_derived_classes_trigger_init_subclass() -> None:
    """Test that derived classes cannot have attributes that are "tensors | None"."""
    # Defining the subclass inside the raises-context: the TypeError is thrown
    # at class-creation time by SimState.__init_subclass__.
    with pytest.raises(TypeError) as excinfo:

        class DerivedState(SimState):
            invalid_attr: torch.Tensor | None = None

    # The error must point at the torch.cat limitation that motivates the rule.
    error_message = str(excinfo.value)
    expected_fragment = (
        "is not allowed to be of type 'torch.Tensor | None' because torch.cat"
    )
    assert expected_fragment in error_message
4 changes: 2 additions & 2 deletions torch_sim/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ class FireState(SimState):
# Required attributes not in SimState
forces: torch.Tensor
energy: torch.Tensor
velocities: torch.Tensor | None
velocities: torch.Tensor

# FIRE algorithm parameters
dt: torch.Tensor
Expand Down Expand Up @@ -972,7 +972,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin):

# Cell attributes
cell_positions: torch.Tensor
cell_velocities: torch.Tensor | None
cell_velocities: torch.Tensor
cell_forces: torch.Tensor
cell_masses: torch.Tensor

Expand Down
4 changes: 2 additions & 2 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,8 @@ def static(
@dataclass
class StaticState(type(state)):
energy: torch.Tensor
forces: torch.Tensor | None
stress: torch.Tensor | None
forces: torch.Tensor
stress: torch.Tensor

all_props: list[dict[str, torch.Tensor]] = []
og_filenames = trajectory_reporter.filenames
Expand Down
28 changes: 28 additions & 0 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import copy
import importlib
import typing
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Self, cast
Expand Down Expand Up @@ -400,6 +401,33 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) ->

return _slice_state(self, system_indices)

def __init_subclass__(cls, **kwargs) -> None:
    """Enforce that all derived states cannot have tensor attributes that can also be
    None. This is because torch.concatenate cannot concat between a tensor and a None.
    See https://github.com/Radical-AI/torch-sim/pull/219 for more details.
    """
    # get_type_hints (rather than raw __annotations__) resolves string/forward
    # annotations into actual type objects before inspection.
    for attr_name, attr_typehint in typing.get_type_hints(cls).items():
        origin = typing.get_origin(attr_typehint)
        if origin is None:
            # Plain (non-generic, non-union) annotation — nothing to validate.
            continue

        # A union is either typing.Union (Optional[...] / Union[...]) or, for the
        # Python 3.10+ `|` syntax, types.UnionType. The latter is matched by
        # module/name to be robust against module reloading/patching issues.
        is_union = origin is typing.Union or (
            origin.__module__ == "types" and origin.__name__ == "UnionType"
        )
        if not is_union:
            continue

        union_members = typing.get_args(attr_typehint)
        if torch.Tensor in union_members and type(None) in union_members:
            raise TypeError(
                f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
                "allowed to be of type 'torch.Tensor | None' because torch.cat "
                "cannot concatenate between a tensor and a None. Please default "
                "the tensor with dummy values and track the 'None' case."
            )

    super().__init_subclass__(**kwargs)


class DeformGradMixin:
"""Mixin for states that support deformation gradients."""
Expand Down