From 29ccf8f75ec0acbbf837d45504c7258e65fbae4e Mon Sep 17 00:00:00 2001 From: orionarcher Date: Mon, 2 Mar 2026 14:00:51 -0500 Subject: [PATCH 1/4] Remove StateDict type and ensure_sim_state helper StateDict = dict was a type alias allowing raw dicts to be passed as state to integrators, optimizers, and models. All callers now pass a proper SimState, so remove the alias, the ensure_sim_state conversion helper, and all StateDict | SimState union annotations across integrators, optimizers, models, the interface ABC, tests, and examples. --- examples/scripts/1_introduction.py | 7 ++-- tests/test_fix_symmetry.py | 5 +-- tests/test_optimizers.py | 17 +++------- torch_sim/integrators/npt.py | 20 ++++------- torch_sim/integrators/nve.py | 11 +++--- torch_sim/integrators/nvt.py | 24 +++++-------- torch_sim/models/fairchem.py | 14 +++----- torch_sim/models/fairchem_legacy.py | 12 +++---- torch_sim/models/graphpes_framework.py | 8 +---- torch_sim/models/interface.py | 16 ++++----- torch_sim/models/lennard_jones.py | 43 ++++-------------------- torch_sim/models/mace.py | 11 ++---- torch_sim/models/mattersim.py | 33 ++++++------------ torch_sim/models/metatomic.py | 21 +++--------- torch_sim/models/morse.py | 23 ++++--------- torch_sim/models/orb.py | 14 +++----- torch_sim/models/particle_life.py | 34 ++----------------- torch_sim/models/sevennet.py | 14 +++----- torch_sim/models/soft_sphere.py | 27 ++++----------- torch_sim/optimizers/bfgs.py | 7 ++-- torch_sim/optimizers/fire.py | 8 ++--- torch_sim/optimizers/gradient_descent.py | 7 ++-- torch_sim/optimizers/lbfgs.py | 8 ++--- torch_sim/state.py | 9 +---- torch_sim/typing.py | 2 -- 25 files changed, 106 insertions(+), 289 deletions(-) diff --git a/examples/scripts/1_introduction.py b/examples/scripts/1_introduction.py index 292b5537e..9b5f4c49a 100644 --- a/examples/scripts/1_introduction.py +++ b/examples/scripts/1_introduction.py @@ -16,6 +16,7 @@ from ase.build import bulk from mace.calculators.foundations_models import mace_mp +import torch_sim as ts from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.models.mace import MaceModel, MaceUrls @@ -90,8 +91,8 @@ # Masses for Argon (39.948 amu) masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype) -# State dict -state = dict( +# SimState +state = ts.SimState( positions=positions, masses=masses, cell=cell.unsqueeze(0), @@ -174,7 +175,7 @@ # Now we can pass them to the model results = batched_model( - dict( + ts.SimState( positions=positions, masses=masses_si, cell=cell, diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index a953cfe88..a770168bb 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -15,7 +15,6 @@ from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import UnbatchedLennardJonesModel from torch_sim.symmetrize import get_symmetry_datasets -from torch_sim.typing import StateDict pytest.importorskip("moyopy") @@ -100,9 +99,7 @@ def __init__( self._compute_stress = model.compute_stress self._compute_forces = model.compute_forces - def forward( - self, state: ts.SimState | StateDict, **kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: """Forward pass with added noise.""" results = self.model(state, **kwargs) for key in ("forces", "stress"): diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index d968a767d..ad79d4a0f 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1,6 +1,5 @@ import copy from collections.abc import Callable -from dataclasses import fields from functools import partial from typing import Any, get_args @@ -572,13 +571,9 @@ def test_simple_optimizer_init_with_dict( ar_supercell_sim_state: SimState, lj_model: ModelInterface, ) -> None: - """Test simple optimizer init_fn with a SimState dictionary.""" - state_dict = { - field.name: getattr(ar_supercell_sim_state, field.name) - for field in fields(ar_supercell_sim_state) - } + """Test simple optimizer init_fn with a SimState.""" init_fn, _ = ts.OPTIM_REGISTRY[optimizer_fn] - opt_state = init_fn(model=lj_model, state=state_dict) + opt_state = init_fn(model=lj_model, state=ar_supercell_sim_state) assert isinstance(opt_state, expected_state_type) assert opt_state.energy is not None assert opt_state.forces is not None @@ -822,15 +817,11 @@ def test_cell_optimizer_init_with_dict_and_cell_factor( ar_supercell_sim_state: SimState, lj_model: ModelInterface, ) -> None: - """Test cell optimizer init_fn with dict state and explicit cell_factor.""" - state_dict = { - f.name: getattr(ar_supercell_sim_state, f.name) - for f in fields(ar_supercell_sim_state) - } + """Test cell optimizer init_fn with explicit cell_factor.""" init_fn, _ = ts.OPTIM_REGISTRY[optimizer_fn] opt_state = init_fn( model=lj_model, - state=state_dict, + state=ar_supercell_sim_state, cell_factor=cell_factor_val, cell_filter=cell_filter, ) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index d15fab7fd..da2d68391 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -19,8 +19,7 @@ ) from torch_sim.integrators.nvt import _vrescale_update from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState, ensure_sim_state -from torch_sim.typing import StateDict +from torch_sim.state import SimState def _randn_for_state(state: MDState, shape: torch.Size | tuple[int, ...]) -> torch.Tensor: @@ -502,7 +501,7 @@ def _compute_cell_force( def npt_langevin_init( - state: SimState | StateDict, + state: SimState, model: ModelInterface, *, kT: float | torch.Tensor, @@ -524,8 +523,7 @@ def npt_langevin_init( Args: model (ModelInterface): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - state (MDState | StateDict): Either a MDState object or a dictionary - containing positions, masses, cell, pbc + state (SimState): SimState containing positions, masses, cell, pbc kT (torch.Tensor): Target temperature in energy units, either scalar or with shape [n_systems] dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems] @@ -561,8 +559,6 @@ def npt_langevin_init( kT = torch.as_tensor(kT, device=device, dtype=dtype) dt = torch.as_tensor(dt, device=device, dtype=dtype) - state = ensure_sim_state(state) - if alpha.ndim == 0: alpha = alpha.expand(state.n_systems) if cell_alpha.ndim == 0: @@ -1280,7 +1276,7 @@ def _npt_nose_hoover_inner_step( def npt_nose_hoover_init( - state: SimState | StateDict, + state: SimState, model: ModelInterface, *, kT: float | torch.Tensor, @@ -1303,7 +1299,7 @@ def npt_nose_hoover_init( Args: model (ModelInterface): Model to compute forces and energies - state: Initial system state as MDState or dict containing positions, masses, + state: Initial system state as SimState containing positions, masses, cell, and PBC information kT: Target temperature in energy units external_pressure: Target external pressure @@ -1335,7 +1331,6 @@ def npt_nose_hoover_init( - Cell dynamics use logarithmic coordinates for volume updates - All cell properties are properly initialized with batch dimensions """ - state = ensure_sim_state(state) device, dtype = state.device, state.dtype dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype) kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype) @@ -2301,7 +2296,7 @@ def npt_crescale_isotropic_step( def npt_crescale_init( - state: SimState | StateDict, + state: SimState, model: ModelInterface, *, kT: float | torch.Tensor, @@ -2321,7 +2316,7 @@ def npt_crescale_init( To seed the RNG set ``state.rng = seed`` before calling. Args: - state: Initial system state as MDState or dict containing positions, masses, + state: Initial system state as SimState containing positions, masses, cell, and PBC information model (ModelInterface): Model to compute forces and energies kT: Target temperature in energy units @@ -2330,7 +2325,6 @@ def npt_crescale_init( isothermal_compressibility: Isothermal compressibility of the system. """ device, dtype = model.device, model.dtype - state = ensure_sim_state(state) # Convert all parameters to tensors with correct device and dtype dt = torch.as_tensor(dt, device=device, dtype=dtype) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 993da6a45..ca7241269 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -11,12 +11,11 @@ position_step, ) from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState, ensure_sim_state -from torch_sim.typing import StateDict +from torch_sim.state import SimState def nve_init( - state: SimState | StateDict, + state: SimState, model: ModelInterface, *, kT: float | torch.Tensor, @@ -33,8 +32,8 @@ def nve_init( Args: model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. - state: Either a SimState object or a dictionary containing positions, - masses, cell, pbc, and other required state variables + state: SimState containing positions, masses, cell, pbc, and other + required state variables kT: Temperature in energy units for initializing momenta, scalar or with shape [n_systems] @@ -46,8 +45,6 @@ def nve_init( - Initial velocities sampled from Maxwell-Boltzmann distribution - Time integration error scales as O(dt²) """ - state = ensure_sim_state(state) - model_output = model(state) momenta = getattr( diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index d4101b1aa..10f2509ad 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -18,8 +18,7 @@ velocity_verlet_step, ) from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState, ensure_sim_state -from torch_sim.typing import StateDict +from torch_sim.state import SimState def _ou_step( @@ -85,7 +84,7 @@ def _ou_step( def nvt_langevin_init( - state: SimState | StateDict, + state: SimState, model: ModelInterface, *, kT: float | torch.Tensor, @@ -102,8 +101,8 @@ def nvt_langevin_init( Args: model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. - state: Either a SimState object or a dictionary containing positions, - masses, cell, pbc, and other required state vars + state: SimState containing positions, masses, cell, pbc, and other + required state vars kT: Temperature in energy units for initializing momenta, either scalar or with shape [n_systems] @@ -116,8 +115,6 @@ def nvt_langevin_init( at the specified temperature. This provides a proper thermal initial state for the subsequent Langevin dynamics. """ - state = ensure_sim_state(state) - model_output = model(state) momenta = getattr( @@ -251,7 +248,7 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: def nvt_nose_hoover_init( - state: SimState | StateDict, + state: SimState, model: ModelInterface, *, kT: float | torch.Tensor, @@ -272,7 +269,7 @@ def nvt_nose_hoover_init( To seed the RNG set ``state.rng = seed`` before calling. Args: - state: Initial system state as SimState or dict + state: Initial system state as SimState model: Neural network model that computes energies and forces kT: Target temperature in energy units dt: Integration timestep @@ -292,7 +289,6 @@ def nvt_nose_hoover_init( - Chain variables evolve to maintain target temperature - Time-reversible when integrated with appropriate algorithms """ - state = ensure_sim_state(state) dt_tensor = torch.as_tensor(dt, device=state.device, dtype=state.dtype) kT_tensor = torch.as_tensor(kT, device=state.device, dtype=state.dtype) tau_tensor = torch.as_tensor( @@ -570,7 +566,7 @@ def _vrescale_update[T: MDState]( def nvt_vrescale_init( - state: SimState | StateDict, + state: SimState, model: ModelInterface, *, kT: float | torch.Tensor, @@ -588,8 +584,8 @@ def nvt_vrescale_init( Args: model: Neural network model that computes energies and forces. Must return a dict with 'energy' and 'forces' keys. - state: Either a SimState object or a dictionary containing positions, - masses, cell, pbc, and other required state vars + state: SimState containing positions, masses, cell, pbc, and other + required state vars kT: Temperature in energy units for initializing momenta, either scalar or with shape [n_systems] @@ -602,8 +598,6 @@ def nvt_vrescale_init( at the specified temperature. The V-Rescale thermostat provides proper canonical sampling through stochastic velocity rescaling. """ - state = ensure_sim_state(state) - model_output = model(state) momenta = getattr( diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index bbcfc274e..32d30f558 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -18,7 +18,6 @@ import torch from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState, ensure_sim_state try: @@ -45,7 +44,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: if typing.TYPE_CHECKING: from collections.abc import Callable - from torch_sim.typing import StateDict + from torch_sim.state import SimState class FairChemModel(ModelInterface): @@ -169,15 +168,12 @@ def device(self) -> torch.device: """Return the device where the model is located.""" return self._device - def forward( - self, state: SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Compute energies, forces, and other properties. Args: - state (SimState | StateDict): State object containing positions, cells, - atomic numbers, and other system information. If a dictionary is provided, - it will be converted to a SimState. + state (SimState): State object containing positions, cells, atomic numbers, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -186,7 +182,7 @@ def forward( - forces (torch.Tensor): Forces with shape [n_atoms, 3] - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] """ - sim_state = ensure_sim_state(state) + sim_state = state if sim_state.device != self._device: sim_state = sim_state.to(self._device) diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py index c3dd2eb9c..a006cb201 100644 --- a/torch_sim/models/fairchem_legacy.py +++ b/torch_sim/models/fairchem_legacy.py @@ -30,11 +30,10 @@ import torch from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState, ensure_sim_state if typing.TYPE_CHECKING: - from torch_sim.typing import StateDict + from torch_sim.state import SimState def _validate_fairchem_version() -> None: @@ -363,7 +362,7 @@ def load_checkpoint( print("Unable to load checkpoint!") def forward( # noqa: C901 - self, state: SimState | StateDict, **_kwargs: object + self, state: SimState, **_kwargs: object ) -> dict[str, torch.Tensor]: """Perform forward pass to compute energies, forces, and other properties. @@ -371,9 +370,8 @@ def forward( # noqa: C901 such as energy, forces, and stresses. Args: - state (SimState | StateDict): State object containing positions, cells, - atomic numbers, and other system information. If a dictionary is provided, - it will be converted to a SimState. + state (SimState): State object containing positions, cells, atomic numbers, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -387,7 +385,7 @@ def forward( # noqa: C901 The state is automatically transferred to the model's device if needed. All output tensors are detached from the computation graph. """ - sim_state = ensure_sim_state(state) + sim_state = state if sim_state.device != self._device: sim_state = sim_state.to(self._device) diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py index 8fb7a0eb7..0bca0cdd5 100644 --- a/torch_sim/models/graphpes_framework.py +++ b/torch_sim/models/graphpes_framework.py @@ -23,8 +23,6 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import ensure_sim_state -from torch_sim.typing import StateDict try: @@ -172,9 +170,7 @@ def __init__( if isinstance(cutoff_val, torch.Tensor) and cutoff_val.item() < 0.5: self._memory_scales_with = "n_atoms" - def forward( - self, state: ts.SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Forward pass for the GraphPESWrapper. Args: @@ -185,8 +181,6 @@ def forward( Dictionary containing the computed energies, forces, and stresses (where applicable) """ - state = ensure_sim_state(state) - cutoff = self._gp_model.cutoff if not isinstance(cutoff, torch.Tensor): raise TypeError("GraphPES model cutoff must be a tensor") diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 8b37c0745..adc0de8b7 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -32,7 +32,7 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): import torch_sim as ts from torch_sim.state import SimState -from torch_sim.typing import MemoryScaling, StateDict +from torch_sim.typing import MemoryScaling class ModelInterface(torch.nn.Module, ABC): @@ -133,7 +133,7 @@ def memory_scales_with(self) -> MemoryScaling: return getattr(self, "_memory_scales_with", "n_atoms_x_density") @abstractmethod - def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]: + def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """Calculate energies, forces, and stresses for a atomistic system. This is the main computational method that all model implementations must provide. @@ -141,13 +141,11 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens containing computed physical properties. Args: - state (SimState | StateDict): Simulation state or state dictionary. The state - dictionary is dependent on the model but typically must contain the - following keys: - - "positions": Atomic positions with shape [n_atoms, 3] - - "cell": Unit cell vectors with shape [n_systems, 3, 3] - - "system_idx": System indices for each atom with shape [n_atoms] - - "atomic_numbers": Atomic numbers with shape [n_atoms] (optional) + state (SimState): Simulation state containing: + - positions: Atomic positions with shape [n_atoms, 3] + - cell: Unit cell vectors with shape [n_systems, 3, 3] + - system_idx: System indices for each atom with shape [n_atoms] + - atomic_numbers: Atomic numbers with shape [n_atoms] (optional) **kwargs: Additional model-specific parameters. Returns: diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 52c22f98d..518355cca 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -32,8 +32,6 @@ from torch_sim import transforms from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import ensure_sim_state -from torch_sim.typing import StateDict DEFAULT_SIGMA = 1.0 @@ -251,8 +249,6 @@ def unbatched_forward( Notes: Neighbor lists are always used to construct interacting pairs. """ - state = ensure_sim_state(state) - positions = state.positions cell = state.row_vector_cell cell = cell.squeeze() @@ -344,18 +340,15 @@ def unbatched_forward( return results - def forward( - self, state: ts.SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Compute Lennard-Jones energies, forces, and stresses for a system. Main entry point for Lennard-Jones calculations that handles batched states by dispatching each system to the unbatched implementation and combining results. Args: - state (SimState | StateDict): Input state containing atomic positions, - cell vectors, and other system information. Can be a SimState object - or a dictionary with the same keys. + state (SimState): Input state containing atomic positions, cell vectors, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -385,19 +378,7 @@ def forward( energies = results["energies"] # Shape: [n_atoms] stresses = results["stresses"] # Shape: [n_atoms, 3, 3] """ - if isinstance(state, ts.SimState): - sim_state = state - else: - state_dict: StateDict = state - positions = state_dict["positions"] - sim_state = ts.SimState( - positions=positions, - masses=torch.ones_like(positions), - cell=state_dict["cell"], - pbc=state_dict["pbc"], - atomic_numbers=state_dict["atomic_numbers"], - system_idx=state_dict.get("system_idx"), - ) + sim_state = state if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: raise ValueError("System can only be inferred for batch size 1.") @@ -430,22 +411,10 @@ class LennardJonesModel(UnbatchedLennardJonesModel): """ def forward( # noqa: PLR0915 - self, state: ts.SimState | StateDict, **_kwargs: object + self, state: ts.SimState, **_kwargs: object ) -> dict[str, torch.Tensor]: """Compute Lennard-Jones properties with batched tensor operations.""" - if isinstance(state, ts.SimState): - sim_state = state - else: - state_dict: StateDict = state - positions = state_dict["positions"] - sim_state = ts.SimState( - positions=positions, - masses=torch.ones_like(positions), - cell=state_dict["cell"], - pbc=state_dict["pbc"], - atomic_numbers=state_dict["atomic_numbers"], - system_idx=state_dict.get("system_idx"), - ) + sim_state = state if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: raise ValueError("System can only be inferred for batch size 1.") diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 5a9b24a93..7dfbb7657 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -29,8 +29,6 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import ensure_sim_state -from torch_sim.typing import StateDict try: @@ -232,7 +230,7 @@ def _setup_node_attrs(self, atomic_numbers: torch.Tensor) -> None: ) def forward( # noqa: C901 - self, state: ts.SimState | StateDict, **_kwargs: object + self, state: ts.SimState, **_kwargs: object ) -> dict[str, torch.Tensor]: """Compute energies, forces, and stresses for the given atomic systems. @@ -241,9 +239,8 @@ def forward( # noqa: C901 multiple systems and constructs the necessary neighbor lists. Args: - state (SimState | StateDict): State object containing positions, cell, - and other system information. Can be either a SimState object or a - dictionary with the relevant fields. + state (SimState): State object containing positions, cell, and other + system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -258,8 +255,6 @@ def forward( # noqa: C901 or in the forward pass, or if provided in both places. ValueError: If system indices are not provided when needed. """ - state = ensure_sim_state(state) - if self.atomic_numbers_in_init: if state.positions.shape[0] != self.atomic_numbers.shape[0]: raise ValueError( diff --git a/torch_sim/models/mattersim.py b/torch_sim/models/mattersim.py index c992a2d34..ce2a68d59 100644 --- a/torch_sim/models/mattersim.py +++ b/torch_sim/models/mattersim.py @@ -14,7 +14,9 @@ try: - from mattersim.datasets.utils.convertor import GraphConvertor + from mattersim.datasets.utils.convertor import ( # codespell:ignore convertor + GraphConvertor, + ) from mattersim.forcefield.potential import batch_to_dict from torch_geometric.loader.dataloader import Collater @@ -36,8 +38,6 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: if TYPE_CHECKING: from mattersim.forcefield import Potential - from torch_sim.typing import StateDict - class MatterSimModel(ModelInterface): """Computes atomistic energies, forces and stresses using an MatterSim model. @@ -100,7 +100,7 @@ def __init__( self.two_body_cutoff = model_args["cutoff"] self.three_body_cutoff = model_args["threebody_cutoff"] - self.convertor = GraphConvertor( + self.convertor = GraphConvertor( # codespell:ignore convertor model_type="m3gnet", twobody_cutoff=self.two_body_cutoff, has_threebody=True, @@ -113,18 +113,15 @@ def __init__( "stress", ] - def forward( - self, state: ts.SimState | StateDict, **_kwargs: Any - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **_kwargs: Any) -> dict[str, torch.Tensor]: """Perform forward pass to compute energies, forces, and other properties. Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses. Args: - state (SimState | StateDict): State object containing positions, cells, - atomic numbers, and other system information. If a dictionary is provided, - it will be converted to a SimState. + state (SimState): State object containing positions, cells, atomic numbers, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -138,24 +135,14 @@ def forward( The state is automatically transferred to the model's device if needed. All output tensors are detached from the computation graph. """ - if isinstance(state, ts.SimState): - sim_state = state - else: - positions = state["positions"] - sim_state = ts.SimState( - positions=positions, - masses=torch.ones_like(positions), - cell=state["cell"], - pbc=state.get("pbc", True), - atomic_numbers=state["atomic_numbers"], - system_idx=state.get("system_idx"), - ) + sim_state = state if sim_state.device != self._device: sim_state = sim_state.to(self._device) atoms_list = ts.io.state_to_atoms(sim_state) - data_list = [self.convertor.convert(atoms) for atoms in atoms_list] + convert = self.convertor.convert # codespell:ignore convertor + data_list = [convert(atoms) for atoms in atoms_list] batched_data = Collater([], follow_batch=None, exclude_keys=None)(data_list) batched_data.to(self._device) output = self.model.forward( diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 0e3595c9c..e4ed32a35 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -21,7 +21,6 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface -from torch_sim.typing import StateDict try: @@ -148,7 +147,7 @@ def __init__( ) def forward( # noqa: C901, PLR0915 - self, state: ts.SimState | StateDict, **_kwargs: Any + self, state: ts.SimState, **_kwargs: Any ) -> dict[str, torch.Tensor]: """Compute energies, forces, and stresses for the given atomic systems. @@ -157,9 +156,8 @@ def forward( # noqa: C901, PLR0915 multiple systems as well as constructing the necessary neighbor lists. Args: - state (SimState | StateDict): State object containing positions, cell, - and other system information. Can be either a SimState object or a - dictionary with the relevant fields. + state (SimState): State object containing positions, cell, and other + system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -169,18 +167,7 @@ def forward( # noqa: C901, PLR0915 - 'stress': System stresses with shape [n_systems, 3, 3] if compute_stress=True """ - if isinstance(state, ts.SimState): - sim_state = state - else: - positions = state["positions"] - sim_state = ts.SimState( - positions=positions, - masses=torch.ones_like(positions), - cell=state["cell"], - pbc=state.get("pbc", True), - atomic_numbers=state["atomic_numbers"], - system_idx=state.get("system_idx"), - ) + sim_state = state # Input validation is already done inside the forward method of the # AtomisticModel class, so we don't need to do it again here. diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index b6ff7236e..2c8818084 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -31,8 +31,6 @@ from torch_sim import transforms from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import ensure_sim_state -from torch_sim.typing import StateDict DEFAULT_SIGMA = 1.0 @@ -227,9 +225,7 @@ def __init__( self.epsilon = torch.as_tensor(epsilon, dtype=self.dtype, device=self.device) self.alpha = torch.as_tensor(alpha, dtype=self.dtype, device=self.device) - def unbatched_forward( - self, state: ts.SimState | StateDict - ) -> dict[str, torch.Tensor]: + def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: """Compute Morse potential properties for a single unbatched system. Internal implementation that processes a single, non-batched simulation state. @@ -237,9 +233,8 @@ def unbatched_forward( neighbor list construction, distance calculations, and property computation. Args: - state (SimState | StateDict): Single, non-batched simulation state or - equivalent dictionary containing atomic positions, cell vectors, - and other system information. + state (SimState): Single, non-batched simulation state containing atomic + positions, cell vectors, and other system information. Returns: dict[str, torch.Tensor]: Computed properties: @@ -256,7 +251,6 @@ def unbatched_forward( This method can work with both neighbor list and full pairwise calculations. In both cases, interactions are truncated at the cutoff distance. """ - state = ensure_sim_state(state) positions = state.positions cell = state.row_vector_cell cell = cell.squeeze() @@ -351,18 +345,15 @@ def unbatched_forward( return results - def forward( - self, state: ts.SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Compute Morse potential energies, forces, and stresses for a system. Main entry point for Morse potential calculations that handles batched states by dispatching each batch to the unbatched implementation and combining results. Args: - state (SimState | StateDict): Input state containing atomic positions, - cell vectors, and other system information. Can be a SimState object - or a dictionary with the same keys. + state (SimState): Input state containing atomic positions, cell vectors, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -387,8 +378,6 @@ def forward( forces = results["forces"] # Shape: [n_atoms, 3] ``` """ - state = ensure_sim_state(state) - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] properties = outputs[0] diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index e8383c4bf..a02025048 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -25,7 +25,6 @@ from torch_sim.elastic import voigt_6_to_full_3x3_stress from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState, ensure_sim_state try: @@ -55,7 +54,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from orb_models.forcefield.direct_regressor import DirectForcefieldRegressor from orb_models.forcefield.featurization_utilities import EdgeCreationMethod - from torch_sim.typing import StateDict + from torch_sim.state import SimState def cell_to_cellpar( @@ -397,18 +396,15 @@ def __init__( if self.conservative: self.implemented_properties.extend(["forces", "stress"]) - def forward( - self, state: SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Perform forward pass to compute energies, forces, and other properties. Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses. Args: - state (SimState | StateDict): State object containing positions, cells, - atomic numbers, and other system information. If a dictionary is provided, - it will be converted to a SimState. + state (SimState): State object containing positions, cells, atomic numbers, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -422,7 +418,7 @@ def forward( The state is automatically transferred to the model's device if needed. All output tensors are detached from the computation graph. """ - sim_state = ensure_sim_state(state) + sim_state = state if sim_state.device != self._device: sim_state = sim_state.to(self._device) diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index eecba4053..c227f2720 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -6,7 +6,6 @@ from torch_sim import transforms from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.typing import StateDict DEFAULT_BETA = torch.tensor(0.3) @@ -141,9 +140,7 @@ def __init__( self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device) self.beta = torch.tensor(beta, dtype=self.dtype, device=self.device) - def unbatched_forward( - self, state: ts.SimState | StateDict - ) -> dict[str, torch.Tensor]: + def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: """Compute energies and forces for a single unbatched system. Internal implementation that processes a single, non-batched simulation state. @@ -157,18 +154,6 @@ def unbatched_forward( Returns: A dictionary containing the energy, forces, and stresses """ - if not isinstance(state, ts.SimState): - state_dict = state - positions_in = state_dict["positions"] - state = ts.SimState( - positions=positions_in, - masses=torch.ones_like(positions_in), - cell=state_dict["cell"], - pbc=state_dict.get("pbc", True), - atomic_numbers=state_dict["atomic_numbers"], - system_idx=state_dict.get("system_idx"), - ) - positions = state.positions cell = state.row_vector_cell @@ -250,9 +235,7 @@ def unbatched_forward( return results - def forward( - self, state: ts.SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Compute particle life energies and forces for a system. Main entry point for particle life calculations that handles batched states by @@ -279,18 +262,7 @@ def forward( Raises: ValueError: If batch cannot be inferred for multi-cell systems. """ - if isinstance(state, ts.SimState): - sim_state = state - else: - positions_in = state["positions"] - sim_state = ts.SimState( - positions=positions_in, - masses=torch.ones_like(positions_in), - cell=state["cell"], - pbc=state.get("pbc", True), - atomic_numbers=state["atomic_numbers"], - system_idx=state.get("system_idx"), - ) + sim_state = state if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: raise ValueError( diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 873697afc..bf0fe735d 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -12,7 +12,6 @@ from torch_sim.elastic import voigt_6_to_full_3x3_stress from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import SimState, ensure_sim_state if TYPE_CHECKING: @@ -20,7 +19,7 @@ from sevenn.nn.sequential import AtomGraphSequential - from torch_sim.typing import StateDict + from torch_sim.state import SimState try: @@ -160,18 +159,15 @@ def __init__( self.implemented_properties = ["energy", "forces", "stress"] - def forward( - self, state: SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Perform forward pass to compute energies, forces, and other properties. Takes a simulation state and computes the properties implemented by the model, such as energy, forces, and stresses. Args: - state (SimState | StateDict): State object containing positions, cells, - atomic numbers, and other system information. If a dictionary is provided, - it will be converted to a SimState. + state (SimState): State object containing positions, cells, atomic numbers, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -185,7 +181,7 @@ def forward( The state is automatically transferred to the model's device if needed. All output tensors are detached from the computation graph. """ - sim_state = ensure_sim_state(state) + sim_state = state if sim_state.device != self._device: sim_state = sim_state.to(self._device) diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 62d8d6d06..0dd1c6b65 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -47,8 +47,6 @@ from torch_sim import transforms from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import ensure_sim_state -from torch_sim.typing import StateDict DEFAULT_SIGMA = torch.tensor(1.0) @@ -277,7 +275,6 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: The soft sphere potential is purely repulsive, and forces are truncated at the cutoff distance. """ - state = ensure_sim_state(state) positions = state.positions cell = state.row_vector_cell cell = cell.squeeze() @@ -385,9 +382,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: return results - def forward( - self, state: ts.SimState | StateDict, **_kwargs - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]: """Compute soft sphere potential energies, forces, and stresses for a system. Main entry point for soft sphere potential calculations that handles batched @@ -395,9 +390,8 @@ def forward( results. Args: - state (SimState | StateDict): Input state containing atomic positions, - cell vectors, and other system information. Can be a SimState object - or a dictionary with the same keys. + state (SimState): Input state containing atomic positions, cell vectors, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -422,8 +416,6 @@ def forward( forces = results["forces"] # Shape: [n_atoms, 3] ``` """ - state = ensure_sim_state(state) - # Handle System indices if not provided if state.system_idx is None and state.cell.shape[0] > 1: raise ValueError( @@ -695,8 +687,6 @@ def unbatched_forward( particles, it looks up the appropriate parameters based on the species of the two particles. """ - state = ensure_sim_state(state) - species_idx = state.atomic_numbers.to(device=self.device, dtype=torch.long) positions = state.positions @@ -806,9 +796,7 @@ def unbatched_forward( return results - def forward( - self, state: ts.SimState | StateDict, **_kwargs - ) -> dict[str, torch.Tensor]: + def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]: """Compute soft sphere potential properties for multi-component systems. Main entry point for multi-species soft sphere calculations that handles @@ -816,9 +804,8 @@ def forward( and combining results. Args: - state (SimState | StateDict): Input state containing atomic positions, - cell vectors, and other system information. Can be a SimState object - or a dictionary with the same keys. + state (SimState): Input state containing atomic positions, cell vectors, + and other system information. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -854,8 +841,6 @@ def forward( This method requires species information either provided during initialization or included in the state object's metadata. """ - state = ensure_sim_state(state) - if state.pbc != self.pbc: raise ValueError("PBC mismatch between model and state") diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index 69d996ed6..5cccf96ca 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -21,8 +21,7 @@ import torch_sim as ts from torch_sim.optimizers import cell_filters from torch_sim.optimizers.cell_filters import CellBFGSState, frechet_cell_filter_init -from torch_sim.state import SimState, ensure_sim_state -from torch_sim.typing import StateDict +from torch_sim.state import SimState if TYPE_CHECKING: @@ -82,7 +81,7 @@ def _pad_to_dense( def bfgs_init( - state: SimState | StateDict, + state: SimState, model: "ModelInterface", *, max_step: float | torch.Tensor = 0.2, @@ -117,8 +116,6 @@ def bfgs_init( device: torch.device = model.device dtype: torch.dtype = model.dtype - state = ensure_sim_state(state) - n_systems = state.n_systems # S counts = state.n_atoms_per_system # [S] diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 2e2b55016..f780648b6 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -8,8 +8,7 @@ import torch_sim.math as tsm from torch_sim._duecredit import dcite from torch_sim.optimizers import CellFireState, cell_filters -from torch_sim.state import SimState, ensure_sim_state -from torch_sim.typing import StateDict +from torch_sim.state import SimState if TYPE_CHECKING: @@ -20,7 +19,7 @@ @dcite("10.1103/PhysRevLett.97.170201") def fire_init( - state: SimState | StateDict, + state: SimState, model: "ModelInterface", *, dt_start: float | torch.Tensor = 0.1, @@ -36,7 +35,7 @@ def fire_init( Args: model: Model that computes energies, forces, and optionally stress - state: Input state as SimState object or state parameter dict + state: Input SimState dt_start: Initial timestep per system alpha_start: Initial mixing parameter per system fire_flavor: Optimization flavor ("vv_fire" or "ase_fire") @@ -59,7 +58,6 @@ def fire_init( device: torch.device = model.device dtype: torch.dtype = model.dtype - state = ensure_sim_state(state) n_systems = state.n_systems diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index cf6f29aec..6f940ff0f 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -6,8 +6,7 @@ import torch_sim as ts from torch_sim.optimizers import cell_filters -from torch_sim.state import SimState, ensure_sim_state -from torch_sim.typing import StateDict +from torch_sim.state import SimState if TYPE_CHECKING: @@ -17,7 +16,7 @@ def gradient_descent_init( - state: SimState | StateDict, + state: SimState, model: "ModelInterface", *, cell_filter: "CellFilter | CellFilterFuncs | None" = None, @@ -41,8 +40,6 @@ def gradient_descent_init( # Import here to avoid circular imports from torch_sim.optimizers import CellOptimState, OptimState - state = ensure_sim_state(state) - # Get initial forces and energy from model model_output = model(state) energy = model_output["energy"] diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index 27d41ce96..0221f92d7 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -23,8 +23,7 @@ frechet_cell_filter_init, ) from torch_sim.optimizers.state import LBFGSState -from torch_sim.state import SimState, ensure_sim_state -from torch_sim.typing import StateDict +from torch_sim.state import SimState if TYPE_CHECKING: @@ -112,7 +111,7 @@ def _per_system_vdot( def lbfgs_init( - state: SimState | StateDict, + state: SimState, model: "ModelInterface", *, step_size: float | torch.Tensor = 0.1, @@ -133,7 +132,7 @@ def lbfgs_init( M_ext = M + 3 (extended with cell DOFs per system) Args: - state: Input state as SimState object or state parameter dict + state: Input SimState model: Model that computes energies, forces, and optionally stress step_size: Fixed per-system step length (damping factor). If using ASE mode (fixed alpha), set this to 1.0 (or your damping). @@ -161,7 +160,6 @@ def lbfgs_init( """ device, dtype = model.device, model.dtype - state = ensure_sim_state(state) n_systems = state.n_systems # S # Compute max atoms per system for per-system history storage diff --git a/torch_sim/state.py b/torch_sim/state.py index d7212f084..461963e7e 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -16,7 +16,7 @@ from torch._prims_common import DeviceLikeType import torch_sim as ts -from torch_sim.typing import PRNGLike, StateDict, StateLike +from torch_sim.typing import PRNGLike, StateLike if TYPE_CHECKING: @@ -62,13 +62,6 @@ def require_system_idx(system_idx: torch.Tensor | None) -> torch.Tensor: return system_idx -def ensure_sim_state(state: "SimState | StateDict") -> "SimState": - """Return a SimState from either SimState or StateDict input.""" - if isinstance(state, SimState): - return state - return SimState(**state) - - @dataclass(kw_only=True) class SimState: """State representation for atomistic systems with batched operations support. diff --git a/torch_sim/typing.py b/torch_sim/typing.py index fd7278e1e..114f0506b 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -48,5 +48,3 @@ class BravaisType(StrEnum): PRNGLike = int | torch.Generator | None MemoryScaling = Literal["n_atoms_x_density", "n_atoms"] -StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "system_idx"] -StateDict = dict From 05c7e489ad876192eb83cab1064dcac25bee3e09 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 3 Mar 2026 15:00:53 -0500 Subject: [PATCH 2/4] Fix pair_potential.py to remove StateDict and ensure_sim_state The pair_potential module was merged in after the StateDict removal commit, so it still referenced the deleted ensure_sim_state helper and StateDict type. Remove them to fix CI. --- torch_sim/models/pair_potential.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/torch_sim/models/pair_potential.py b/torch_sim/models/pair_potential.py index 901a23c90..6d13dbecb 100644 --- a/torch_sim/models/pair_potential.py +++ b/torch_sim/models/pair_potential.py @@ -37,14 +37,13 @@ from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import SimState, ensure_sim_state from torch_sim.transforms import compute_cell_shifts, pbc_wrap_batched if TYPE_CHECKING: from collections.abc import Callable - from torch_sim.typing import StateDict + from torch_sim.state import SimState @torch.jit.script @@ -294,7 +293,7 @@ def full_to_half_list( def _prepare_pairs( - state: SimState | StateDict, + state: SimState, *, cutoff: torch.Tensor, neighbor_list_fn: Callable, @@ -314,7 +313,7 @@ def _prepare_pairs( int, # n_systems ]: """Unpack state, build neighbor list, compute pair vectors and distances.""" - sim_state = ensure_sim_state(state) + sim_state = state positions = sim_state.positions row_cell = sim_state.row_vector_cell @@ -510,13 +509,11 @@ def __init__( self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device) self.reduce_to_half_list = reduce_to_half_list - def forward( - self, state: SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Compute pair-potential properties with batched tensor operations. Args: - state: Simulation state or equivalent state dict. + state: Simulation state. **_kwargs: Unused; accepted for interface compatibility. Returns: @@ -677,13 +674,11 @@ def __init__( self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device) self.reduce_to_half_list = reduce_to_half_list - def forward( - self, state: SimState | StateDict, **_kwargs: object - ) -> dict[str, torch.Tensor]: + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Compute forces from a direct pair force function. Args: - state: Simulation state or equivalent state dict. + state: Simulation state. **_kwargs: Unused; accepted for interface compatibility. Returns: From c1b10051fa1721ae5dd2b1f8365be2f8aba7f897 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 3 Mar 2026 15:12:05 -0500 Subject: [PATCH 3/4] =?UTF-8?q?Change=20default=20cutoff=20from=207.0=20to?= =?UTF-8?q?=206.0=20=C3=85=20in=20autobatching?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_sim/autobatching.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 81821ff00..26a6775aa 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -352,7 +352,7 @@ def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]: def calculate_memory_scalers( state: SimState, memory_scales_with: MemoryScaling = "n_atoms_x_density", - cutoff: float = 7.0, + cutoff: float = 6.0, ) -> list[float]: """Calculate a metric that estimates memory requirements for each system in a state. @@ -564,7 +564,7 @@ def __init__( model: ModelInterface, *, memory_scales_with: MemoryScaling = "n_atoms_x_density", - cutoff: float = 7.0, + cutoff: float = 6.0, max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, memory_scaling_factor: float = 1.6, @@ -868,7 +868,7 @@ def __init__( model: ModelInterface, *, memory_scales_with: MemoryScaling = "n_atoms_x_density", - cutoff: float = 7.0, + cutoff: float = 6.0, max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, memory_scaling_factor: float = 1.6, From e5537e1e54721e03c518a588c7cb2f21f7f19888 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 4 Mar 2026 11:14:17 -0500 Subject: [PATCH 4/4] Require nequip>=0.17.0 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ac4cabe61..59f0a201a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn[torchsim]>=0.12.1"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] -nequip = ["nequip>=0.16.2"] +nequip = ["nequip>=0.17.0"] fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] docs = [ "autodoc_pydantic==2.2.0", @@ -132,7 +132,7 @@ docstring-code-format = true [tool.codespell] check-filenames = true -ignore-words-list = ["convertor"] +ignore-words-list = ["convertor"] # codespell:ignore convertor [tool.pytest.ini_options] addopts = ["-p no:warnings"]