From aee034c338ae83504700599bf9de88e34aae6456 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 22 Nov 2025 19:51:53 -0800 Subject: [PATCH] v1 of attributes add a few tests and rename to extra_x_attributes get value strict static state can now handle arbitrary things update model interface update validate model outputs proper runner code handle deprecated model attributes fix state init revert test_state wrote tests for state more fixes more fixes better set attribute validation more tests more progress towards better attributes assert we are only concatenating states with the same attribute names fix concat state --- tests/test_state.py | 220 ++++++++++++++++++++++++++++++---- torch_sim/models/interface.py | 131 ++++++++++++-------- torch_sim/runners.py | 46 +++---- torch_sim/state.py | 172 ++++++++++++++++++++++++-- 4 files changed, 452 insertions(+), 117 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 1b2936681..ab8e42bed 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -23,14 +23,49 @@ from pymatgen.core import Structure -def test_get_attrs_for_scope(si_sim_state: SimState) -> None: +@pytest.fixture +def si_sim_state_extra(si_sim_state: SimState) -> SimState: + """Create a basic state from si_structure with extra attributes.""" + si_sim_state.set( + "charge", + torch.tensor([3.0] * si_sim_state.n_atoms, device=si_sim_state.device), + "atom", + ) + si_sim_state.set( + "energy", + torch.tensor([3.0], device=si_sim_state.device), + "system", + ) + si_sim_state.set( + "max_steps", + 100, + "global", + ) + return si_sim_state + + +@pytest.fixture +def si_double_sim_state_extra(si_sim_state_extra: SimState) -> SimState: + return ts.concatenate_states( + [si_sim_state_extra, si_sim_state_extra], + device=si_sim_state_extra.device, + ) + + +def test_get_attrs_for_scope(si_sim_state_extra: SimState) -> None: """Test getting attributes for a scope.""" - per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) - 
assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} - per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs) == {"cell"} - global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) - assert set(global_attrs) == {"pbc"} + per_atom_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "per-atom")) + assert set(per_atom_attrs) == { + "positions", + "masses", + "atomic_numbers", + "system_idx", + "charge", + } + per_system_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "per-system")) + assert set(per_system_attrs) == {"cell", "energy"} + global_attrs = dict(get_attrs_for_scope(si_sim_state_extra, "global")) + assert set(global_attrs) == {"pbc", "max_steps"} def test_all_attributes_must_be_specified_in_scopes() -> None: @@ -66,19 +101,32 @@ class ChildState(SimState): assert "duplicated_attribute" in str(exc_info.value) -def test_slice_substate(si_double_sim_state: SimState, si_sim_state: SimState) -> None: +def test_slice_substate( + si_double_sim_state_extra: SimState, si_sim_state_extra: SimState +) -> None: """Test slicing a substate from the SimState.""" for system_index in range(2): - substate = _slice_state(si_double_sim_state, [system_index]) + substate = _slice_state(si_double_sim_state_extra, [system_index]) assert isinstance(substate, SimState) assert substate.positions.shape == (8, 3) assert substate.masses.shape == (8,) assert substate.cell.shape == (1, 3, 3) - assert torch.allclose(substate.positions, si_sim_state.positions) - assert torch.allclose(substate.masses, si_sim_state.masses) - assert torch.allclose(substate.cell, si_sim_state.cell) - assert torch.allclose(substate.atomic_numbers, si_sim_state.atomic_numbers) + assert substate.get_strict("charge").shape == (8,) + assert substate.get_strict("energy").shape == (1,) + assert torch.allclose(substate.positions, si_sim_state_extra.positions) + assert torch.allclose(substate.masses, si_sim_state_extra.masses) 
+ assert torch.allclose(substate.cell, si_sim_state_extra.cell) + assert torch.allclose(substate.atomic_numbers, si_sim_state_extra.atomic_numbers) assert torch.allclose(substate.system_idx, torch.zeros_like(substate.system_idx)) + assert torch.allclose( + substate.get_strict("charge"), si_sim_state_extra.get_strict("charge") + ) + assert torch.allclose( + substate.get_strict("energy"), si_sim_state_extra.get_strict("energy") + ) + assert substate.get_strict("max_steps") == si_sim_state_extra.get_strict( + "max_steps" + ) def test_slice_md_substate(si_double_sim_state: SimState) -> None: @@ -100,19 +148,30 @@ def test_slice_md_substate(si_double_sim_state: SimState) -> None: def test_concatenate_two_si_states( - si_sim_state: SimState, si_double_sim_state: SimState + si_sim_state_extra: SimState, si_double_sim_state_extra: SimState ) -> None: """Test concatenating two identical silicon states.""" # Concatenate two copies of the sim state - concatenated = ts.concatenate_states([si_sim_state, si_sim_state]) + concatenated = ts.concatenate_states([si_sim_state_extra, si_sim_state_extra]) # Check that the result is the same as the double state assert isinstance(concatenated, SimState) - assert concatenated.positions.shape == si_double_sim_state.positions.shape - assert concatenated.masses.shape == si_double_sim_state.masses.shape - assert concatenated.cell.shape == si_double_sim_state.cell.shape - assert concatenated.atomic_numbers.shape == si_double_sim_state.atomic_numbers.shape - assert concatenated.system_idx.shape == si_double_sim_state.system_idx.shape + assert concatenated.positions.shape == si_double_sim_state_extra.positions.shape + assert concatenated.masses.shape == si_double_sim_state_extra.masses.shape + assert concatenated.cell.shape == si_double_sim_state_extra.cell.shape + assert ( + concatenated.atomic_numbers.shape + == si_double_sim_state_extra.atomic_numbers.shape + ) + assert concatenated.system_idx.shape == 
si_double_sim_state_extra.system_idx.shape + assert ( + concatenated.get_strict("charge").shape + == si_double_sim_state_extra.get_strict("charge").shape + ) + assert ( + concatenated.get_strict("energy").shape + == si_double_sim_state_extra.get_strict("energy").shape + ) # Check system indices tensor_args = dict(dtype=torch.int64, device=si_sim_state.device) @@ -133,6 +192,17 @@ def test_concatenate_two_si_states( si_double_sim_state.positions[mask_double], ) + # check that the extra attributes are concatenated correctly + assert torch.allclose( + concatenated.get_strict("charge"), si_double_sim_state_extra.get_strict("charge") + ) + assert torch.allclose( + concatenated.get_strict("energy"), si_double_sim_state_extra.get_strict("energy") + ) + assert concatenated.get_strict("max_steps") == si_double_sim_state_extra.get_strict( + "max_steps" + ) + def test_concatenate_si_and_fe_states( si_sim_state: SimState, fe_supercell_sim_state: SimState @@ -228,16 +298,26 @@ def test_concatenate_double_si_and_fe_states( assert torch.allclose(fe_slice.positions, fe_supercell_sim_state.positions) -def test_split_state(si_double_sim_state: SimState) -> None: +def test_concatenate_states_with_inconsistent_extra_attributes( + si_sim_state: SimState, si_sim_state_extra: SimState +) -> None: + """We should only be able to concat states with the same extra attributes.""" + with pytest.raises(ValueError): + ts.concatenate_states([si_sim_state, si_sim_state_extra]) + + +def test_split_state(si_double_sim_state_extra: SimState) -> None: """Test splitting a state into a list of states.""" - states = si_double_sim_state.split() - assert len(states) == si_double_sim_state.n_systems + states = si_double_sim_state_extra.split() + assert len(states) == si_double_sim_state_extra.n_systems for state in states: assert isinstance(state, SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) assert state.atomic_numbers.shape == (8,) 
+ assert state.get_strict("charge").shape == (8,) + assert state.get_strict("energy").shape == (1,) assert torch.allclose(state.system_idx, torch.zeros_like(state.system_idx)) @@ -287,6 +367,31 @@ def test_pop_states( assert kept_state.system_idx.shape == (len_kept,) +def test_pop_states_with_extra_attributes(si_double_sim_state_extra: SimState) -> None: + """Test popping states with extra attributes.""" + kept_state, popped_states = _pop_states( + si_double_sim_state_extra, + torch.tensor([0], device=si_double_sim_state_extra.device), + ) + assert isinstance(kept_state, SimState) + assert isinstance(popped_states, list) + assert len(popped_states) == 1 + assert isinstance(popped_states[0], SimState) + assert popped_states[0].positions.shape == si_double_sim_state_extra.positions.shape + assert popped_states[0].get_strict("charge").shape == (8,) + assert popped_states[0].get_strict("energy").shape == (1,) + assert popped_states[0].get_strict( + "max_steps" + ) == si_double_sim_state_extra.get_strict("max_steps") + + assert kept_state.positions.shape == (8, 3) + assert kept_state.get_strict("charge").shape == (8,) + assert kept_state.get_strict("energy").shape == (1,) + assert kept_state.get_strict("max_steps") == si_double_sim_state_extra.get_strict( + "max_steps" + ) + + def test_initialize_state_from_structure(si_structure: "Structure") -> None: """Test conversion from pymatgen Structure to state tensors.""" state = ts.initialize_state([si_structure], DEVICE, torch.float64) @@ -650,3 +755,72 @@ def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None: "New state doesn't have correct device!" ) assert si_sim_state is not new_state_gpu, "New state is not a different object!" 
+ + +def test_state_clone(si_sim_state_extra: SimState) -> None: + """Test the clone method of SimState.""" + cloned_state = si_sim_state_extra.clone() + assert isinstance(cloned_state, SimState) + assert cloned_state is not si_sim_state_extra + + for attr in si_sim_state_extra.attributes: + attr_cloned = getattr(cloned_state, attr) + attr_original = getattr(si_sim_state_extra, attr) + if isinstance(attr_cloned, torch.Tensor): + assert torch.allclose(attr_cloned, attr_original) + else: + assert attr_cloned == attr_original + assert attr_cloned is not attr_original + + +def test_state_get_attribute_with_extra_attributes(si_sim_state_extra: SimState) -> None: + """Test the __getitem__ method of SimState with extra attributes.""" + state = si_sim_state_extra + assert torch.equal(state.get_strict("charge"), state._extra_atom_attributes["charge"]) + assert torch.equal( + state.get_strict("energy"), state._extra_system_attributes["energy"] + ) + assert state.get_strict("max_steps") == state._extra_global_attributes["max_steps"] + + +def test_state_set_attribute_with_extra_attributes( + si_sim_state: SimState, si_sim_state_extra: SimState +) -> None: + """Test the __setitem__ method of SimState with extra attributes.""" + target_state = si_sim_state_extra + si_sim_state.set("charge", si_sim_state_extra.get_strict("charge"), "atom") + si_sim_state.set("energy", si_sim_state_extra.get_strict("energy"), "system") + si_sim_state.set("max_steps", si_sim_state_extra.get_strict("max_steps"), "global") + + assert torch.equal( + si_sim_state.get_strict("charge"), target_state.get_strict("charge") + ) + assert torch.equal( + si_sim_state.get_strict("energy"), target_state.get_strict("energy") + ) + assert si_sim_state.get_strict("max_steps") == target_state.get_strict("max_steps") + + +def test_state_set_attribute_must_specify_kind( + si_sim_state: SimState, si_sim_state_extra: SimState +) -> None: + with pytest.raises(ValueError) as exc_info: + si_sim_state.set("charge", 
torch.randn(si_sim_state.n_atoms, 1)) + assert "Kind must be specified for extra attributes" in str(exc_info.value) + + +def test_state_set_attribute_with_default_attributes(si_sim_state: SimState) -> None: + """Test the __setitem__ method of SimState with default attributes.""" + new_positions = torch.randn(si_sim_state.n_atoms, 3) + new_cell = torch.randn(si_sim_state.n_systems, 3, 3) + new_pbc = torch.tensor([True, False, True]) + + si_sim_state.set("positions", new_positions) + si_sim_state.set("cell", new_cell) + + # test that we can optionally specify a kind + si_sim_state.set("pbc", new_pbc, kind="global") + + assert torch.allclose(si_sim_state.positions, new_positions) + assert torch.allclose(si_sim_state.cell, new_cell) + assert torch.allclose(si_sim_state.pbc, new_pbc) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 58f233e84..e4e1edec3 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -27,6 +27,8 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): """ from abc import ABC, abstractmethod +from typing import TypedDict +from typing_extensions import deprecated import torch @@ -35,13 +37,28 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): from torch_sim.typing import MemoryScaling, StateDict +class ModelInterfaceOutput(TypedDict): + """The expected output of a model forward pass implementation.""" + + atom_attributes: dict[str, torch.Tensor] + system_attributes: dict[str, torch.Tensor] + global_attributes: dict[str, torch.Tensor] + + # deprecated attributes. People who've written their own model interfaces should move + # away from this and write their results to atom_attributes, system_attributes, and + # global_attributes. + energy: torch.Tensor | None + forces: torch.Tensor | None + stress: torch.Tensor | None + + class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in TorchSim. 
This interface provides a common structure for all energy and force models, ensuring they implement the required methods and properties. It defines how - models should process atomic positions and system information to compute energies, - forces, and stresses. + models should process atomic positions and system information to compute + system-wide attributes like energies/stresses, or atom-wise attributes like forces. Attributes: device (torch.device): Device where the model runs computations. @@ -133,7 +150,7 @@ def memory_scales_with(self) -> MemoryScaling: return getattr(self, "_memory_scales_with", "n_atoms_x_density") @abstractmethod - def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]: + def forward(self, state: SimState | StateDict, **kwargs) -> ModelInterfaceOutput: """Calculate energies, forces, and stresses for a atomistic system. This is the main computational method that all model implementations must provide. @@ -151,27 +168,32 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens **kwargs: Additional model-specific parameters. 
Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_systems] - - "forces": Atomic forces with shape [n_atoms, 3] - - "stress": Stress tensor with shape [n_systems, 3, 3] (if - compute_stress=True) - - May include additional model-specific outputs + ModelInterfaceOutput: Computed properties: + - "atom_attributes": Dictionary of atom-wise attributes + - "system_attributes": Dictionary of system-wide attributes + - "global_attributes": Dictionary of global attributes Examples: ```py # Compute energies and forces with a model output = model.forward(state) - energy = output["energy"] - forces = output["forces"] - stress = output.get("stress", None) + energy = output["system_attributes"]["energy"] + forces = output["atom_attributes"]["forces"] + stress = output["system_attributes"].get("stress") ``` """ +# TODO: we should put this logic inside __init_subclass__ of Modelinterface to +# automatically validate the model outputs when the model is subclassed. def validate_model_outputs( # noqa: C901, PLR0915 - model: ModelInterface, device: torch.device, dtype: torch.dtype + model: ModelInterface, + device: torch.device, + dtype: torch.dtype, + expected_output_atom_attributes: set[str], + expected_output_system_attributes: set[str], + expected_output_global_attributes: set[str], ) -> None: """Validate the outputs of a model implementation against the interface requirements. @@ -183,7 +205,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 model (ModelInterface): Model implementation to validate. device (torch.device): Device to run the validation tests on. dtype (torch.dtype): Data type to use for validation tensors. - + expected_output_attributes (set[str]): The attributes that the model is expected + to return. Raises: AssertionError: If the model doesn't conform to the required interface, including issues with output shapes, types, or behavior consistency. 
@@ -243,46 +266,54 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") # assert model output has the correct keys - if "energy" not in model_output: - raise ValueError("energy not in model output") - if force_computed and "forces" not in model_output: - raise ValueError("forces not in model output") - if stress_computed and "stress" not in model_output: - raise ValueError("stress not in model output") - - # assert model output shapes are correct - if model_output["energy"].shape != (2,): - raise ValueError(f"{model_output['energy'].shape=} != (2,)") - if force_computed and model_output["forces"].shape != (20, 3): - raise ValueError(f"{model_output['forces'].shape=} != (20, 3)") - if stress_computed and model_output["stress"].shape != (2, 3, 3): - raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)") + for attr in expected_output_atom_attributes: + if attr not in model_output["atom_attributes"]: + raise ValueError(f"{attr} not in model output") + for attr in expected_output_system_attributes: + if attr not in model_output["system_attributes"]: + raise ValueError(f"{attr} not in model output") + for attr in expected_output_global_attributes: + if attr not in model_output["global_attributes"]: + raise ValueError(f"{attr} not in model output") si_state = ts.io.atoms_to_state([si_atoms], device, dtype) fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) si_model_output = model.forward(si_state) - if not torch.allclose( - si_model_output["energy"], model_output["energy"][0], atol=10e-3 - ): - raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}") - if not torch.allclose( - forces := si_model_output["forces"], - expected_forces := model_output["forces"][: si_state.n_atoms], - atol=10e-3, - ): - raise ValueError(f"{forces=} != {expected_forces=}") - fe_model_output = model.forward(fe_state) - si_model_output = model.forward(si_state) - if not torch.allclose( 
- fe_model_output["energy"], model_output["energy"][1], atol=10e-2 - ): - raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}") - if not torch.allclose( - forces := fe_model_output["forces"], - expected_forces := model_output["forces"][si_state.n_atoms :], - atol=10e-2, - ): - raise ValueError(f"{forces=} != {expected_forces=}") + for attr in expected_output_atom_attributes: + if attr in model_output["atom_attributes"]: + si_attr = si_model_output["atom_attributes"][attr] + batched_attr = model_output["atom_attributes"][attr] + expected_attr = batched_attr[: si_state.n_atoms] + if not torch.allclose(si_attr, expected_attr, atol=10e-3): + raise ValueError(f"{attr}: {si_attr=} != {expected_attr=}") + + fe_attr = fe_model_output["atom_attributes"][attr] + expected_fe_attr = batched_attr[si_state.n_atoms :] + if not torch.allclose(fe_attr, expected_fe_attr, atol=10e-2): + raise ValueError(f"{attr}: {fe_attr=} != {expected_fe_attr=}") + + for attr in expected_output_system_attributes: + if attr in model_output["system_attributes"]: + si_attr = si_model_output["system_attributes"][attr] + batched_attr = model_output["system_attributes"][attr] + expected_attr = batched_attr[0] + if not torch.allclose(si_attr, expected_attr, atol=10e-3): + raise ValueError(f"{attr}: {si_attr=} != {expected_attr=}") + + fe_attr = fe_model_output["system_attributes"][attr] + expected_fe_attr = batched_attr[1] + if not torch.allclose(fe_attr, expected_fe_attr, atol=10e-2): + raise ValueError(f"{attr}: {fe_attr=} != {expected_fe_attr=}") + + for attr in expected_output_global_attributes: + if attr in model_output["global_attributes"]: + si_attr = si_model_output["global_attributes"][attr] + fe_attr = fe_model_output["global_attributes"][attr] + batched_attr = model_output["global_attributes"][attr] + if not torch.allclose(si_attr, batched_attr, atol=10e-3): + raise ValueError(f"{attr}: {si_attr=} != {batched_attr=}") + if not torch.allclose(fe_attr, batched_attr, 
atol=10e-2): + raise ValueError(f"{attr}: {fe_attr=} != {batched_attr=}") diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b2059aac0..0f2bb4027 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -574,18 +574,6 @@ def static( properties=properties, ) - @dataclass(kw_only=True) - class StaticState(SimState): - energy: torch.Tensor - forces: torch.Tensor - stress: torch.Tensor - - _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 - _system_attributes = SimState._system_attributes | { # noqa: SLF001 - "energy", - "stress", - } - all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames @@ -606,25 +594,21 @@ class StaticState(SimState): ) model_outputs = model(sub_state) - static_state = StaticState( - positions=sub_state.positions, - masses=sub_state.masses, - cell=sub_state.cell, - pbc=sub_state.pbc, - atomic_numbers=sub_state.atomic_numbers, - system_idx=sub_state.system_idx, - energy=model_outputs["energy"], - forces=( - model_outputs["forces"] - if model.compute_forces - else torch.full_like(sub_state.positions, fill_value=float("nan")) - ), - stress=( - model_outputs["stress"] - if model.compute_stress - else torch.full_like(sub_state.cell, fill_value=float("nan")) - ), - ) + static_state = sub_state.clone() + for attribute_name, value in model_outputs["atom_attributes"].items(): + static_state.set(attribute_name, value, "atom") + for attribute_name, value in model_outputs["system_attributes"].items(): + static_state.set(attribute_name, value, "system") + for attribute_name, value in model_outputs["global_attributes"].items(): + static_state.set(attribute_name, value, "global") + + # Handle deprecated model outputs + if "energy" in model_outputs: + static_state.set("energy", model_outputs["energy"], "system") + if "forces" in model_outputs: + static_state.set("forces", model_outputs["forces"], "atom") + if "stress" in model_outputs: + static_state.set("stress", 
model_outputs["stress"], "system") props = trajectory_reporter.report(static_state, 0, model=model) all_props.extend(props) diff --git a/torch_sim/state.py b/torch_sim/state.py index 813354fe8..533af247a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -10,7 +10,7 @@ from collections import defaultdict from collections.abc import Generator, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, Set, TypeVar import torch @@ -98,6 +98,10 @@ def pbc(self) -> torch.Tensor: """A getter for pbc that tells type checkers it's always defined.""" return self.pbc + _extra_atom_attributes: dict[str, torch.Tensor] = field(default_factory=dict) + _extra_system_attributes: dict[str, torch.Tensor] = field(default_factory=dict) + _extra_global_attributes: dict[str, Any] = field(default_factory=dict) + _atom_attributes: ClassVar[set[str]] = { "positions", "masses", @@ -204,7 +208,7 @@ def volume(self) -> torch.Tensor: def attributes(self) -> dict[str, torch.Tensor]: """Get all public attributes of the state.""" return { - attr: getattr(self, attr) + attr: self.get(attr) for attr in self._atom_attributes | self._system_attributes | self._global_attributes @@ -402,6 +406,74 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def get(self, attribute_name: str) -> Any: + """Get the attribute of the state.""" + if ( + attribute_name + in self._atom_attributes | self._system_attributes | self._global_attributes + ): + return getattr(self, attribute_name) + + if attribute_name in self._extra_atom_attributes: + return self._extra_atom_attributes[attribute_name] + if attribute_name in self._extra_system_attributes: + return self._extra_system_attributes[attribute_name] + if attribute_name in self._extra_global_attributes: + return self._extra_global_attributes[attribute_name] + return 
None + + def get_strict(self, attribute_name: str) -> Any: + """Get the attribute of the state. + + Raises a ValueError if the attribute is not found. + """ + res = self.get(attribute_name) + if res is None: + raise ValueError(f"Attribute '{attribute_name}' not found in state") + return res + + def set( + self, + attribute_name: str, + value: Any, + kind: Literal["atom", "system", "global"] | None = None, + ) -> None: + """Set the attribute of the state.""" + # 1) Handle special cases for default attributes + all_default_attributes = ( + self._atom_attributes | self._system_attributes | self._global_attributes + ) + if attribute_name in all_default_attributes: + # no need to check kind since it's already a default attribute + setattr(self, attribute_name, value) + return + + # 2) validate the kind and value + if kind is None: + raise ValueError("Kind must be specified for extra attributes") + if kind in ("atom", "system") and not isinstance(value, torch.Tensor): + raise ValueError(f"Value for '{attribute_name}' must be a torch.Tensor") + + # 3) Write the value to the appropriate extra attribute + if kind == "atom": + if value.shape[0] != self.n_atoms: + raise ValueError( + f"Value for '{attribute_name}' must have shape (n_atoms, ...)" + ) + self._extra_atom_attributes[attribute_name] = value + elif kind == "system": + if value.shape[0] != self.n_systems: + raise ValueError( + f"Value for '{attribute_name}' must have shape (n_systems, ...)" + ) + self._extra_system_attributes[attribute_name] = value + elif kind == "global": + self._extra_global_attributes[attribute_name] = value + else: + raise ValueError( + f"Invalid kind: {kind}. Must be 'atom', 'system', or 'global'." + ) + def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. 
@@ -617,7 +689,9 @@ def _state_to_device[T: SimState]( def get_attrs_for_scope( - state: SimState, scope: Literal["per-atom", "per-system", "global"] + state: SimState, + scope: Literal["per-atom", "per-system", "global"], + attribute_kind: Literal["only_default", "only_extra", "all"] = "all", ) -> Generator[tuple[str, Any], None, None]: """Get attributes for a given scope. @@ -629,17 +703,27 @@ def get_attrs_for_scope( Returns: Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ + attr_names = set[str]() match scope: case "per-atom": - attr_names = state._atom_attributes # noqa: SLF001 + if attribute_kind in ["only_default", "all"]: + attr_names |= state._atom_attributes # noqa: SLF001 + if attribute_kind in ["only_extra", "all"]: + attr_names |= state._extra_atom_attributes.keys() # noqa: SLF001 case "per-system": - attr_names = state._system_attributes # noqa: SLF001 + if attribute_kind in ["only_default", "all"]: + attr_names |= state._system_attributes # noqa: SLF001 + if attribute_kind in ["only_extra", "all"]: + attr_names |= state._extra_system_attributes.keys() # noqa: SLF001 case "global": - attr_names = state._global_attributes # noqa: SLF001 + if attribute_kind in ["only_default", "all"]: + attr_names |= state._global_attributes # noqa: SLF001 + if attribute_kind in ["only_extra", "all"]: + attr_names |= state._extra_global_attributes.keys() # noqa: SLF001 case _: raise ValueError(f"Unknown scope: {scope!r}") for attr_name in attr_names: - yield attr_name, getattr(state, attr_name) + yield attr_name, state.get(attr_name) def _filter_attrs_by_mask( @@ -870,15 +954,31 @@ def concatenate_states[T: SimState]( # noqa: C901 if not all(isinstance(state, state_class) for state in states): raise TypeError("All states must be of the same type") + # ensure all states have the same extra attributes + first_state_attribute_names = extra_attribute_names(first_state) + if not all( + extra_attribute_names(state) == 
first_state_attribute_names for state in states + ): + raise ValueError( + "All states must have the same extra attributes. Currently, the first state " + "has these extra attributes: {first_state_attribute_names}" + ) + # Use the target device or default to the first state's device target_device = device or first_state.device # Initialize result with global properties from first state - concatenated = dict(get_attrs_for_scope(first_state, "global")) + concatenated = dict(get_attrs_for_scope(first_state, "global", "only_default")) + concatenated_global_extra = dict( + get_attrs_for_scope(first_state, "global", "only_extra") + ) # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + per_atom_tensors_extra = defaultdict[str, list[torch.Tensor]](list) + per_system_tensors_extra = defaultdict[str, list[torch.Tensor]](list) + new_system_indices = [] system_offset = 0 @@ -889,16 +989,30 @@ def concatenate_states[T: SimState]( # noqa: C901 state = state.to(target_device) # Collect per-atom properties - for prop, val in get_attrs_for_scope(state, "per-atom"): + for prop, val in get_attrs_for_scope( + state, "per-atom", attribute_kind="only_default" + ): if prop == "system_idx": # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) + for prop, val in get_attrs_for_scope( + state, "per-atom", attribute_kind="only_extra" + ): + per_atom_tensors_extra[prop].append(val) + # Collect per-system properties - for prop, val in get_attrs_for_scope(state, "per-system"): + for prop, val in get_attrs_for_scope( + state, "per-system", attribute_kind="only_default" + ): per_system_tensors[prop].append(val) + for prop, val in get_attrs_for_scope( + state, "per-system", attribute_kind="only_extra" + ): + per_system_tensors_extra[prop].append(val) + # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset @@ -907,21 +1021,53 @@ def 
concatenate_states[T: SimState]( # noqa: C901 # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): - # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) for prop, tensors in per_system_tensors.items(): - # if tensors: if isinstance(tensors[0], torch.Tensor): concatenated[prop] = torch.cat(tensors, dim=0) else: # Non-tensor attributes, take first one (they should all be identical) concatenated[prop] = tensors[0] + # concatenate the extra attributes + concatenated_per_atom_extra = dict( + get_attrs_for_scope(first_state, "per-atom", "only_extra") + ) + concatenated_per_system_extra = dict( + get_attrs_for_scope(first_state, "per-system", "only_extra") + ) + for prop, tensors in per_atom_tensors_extra.items(): + concatenated_per_atom_extra[prop] = torch.cat(tensors, dim=0) + for prop, tensors in per_system_tensors_extra.items(): + if isinstance(tensors[0], torch.Tensor): + concatenated_per_system_extra[prop] = torch.cat(tensors, dim=0) + else: # Non-tensor attributes, take first one (they should all be identical) + concatenated_per_system_extra[prop] = tensors[0] + # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class - return state_class(**concatenated) + new_state = state_class(**concatenated) + + # Add the extra attributes (since these attributes are not in the class' constructor) + for prop, value in concatenated_per_atom_extra.items(): + new_state.set(prop, value, kind="atom") + for prop, value in concatenated_per_system_extra.items(): + new_state.set(prop, value, kind="system") + for prop, value in concatenated_global_extra.items(): + new_state.set(prop, value, kind="global") + + return new_state + + +def extra_attribute_names(state: SimState) -> set[str]: + """Get the names of the extra attributes of the state.""" + return ( + state._extra_atom_attributes.keys() # noqa: SLF001 + | state._extra_system_attributes.keys() # noqa: SLF001 + | 
state._extra_global_attributes.keys() # noqa: SLF001 + ) def initialize_state(