Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/scripts/1_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.telemetry import configure_logging, get_logger
Expand Down Expand Up @@ -94,8 +95,8 @@
# Masses for Argon (39.948 amu)
masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype)

# State dict
state = dict(
# SimState
state = ts.SimState(
positions=positions,
masses=masses,
cell=cell.unsqueeze(0),
Expand Down Expand Up @@ -178,7 +179,7 @@

# Now we can pass them to the model
results = batched_model(
dict(
ts.SimState(
positions=positions,
masses=masses_si,
cell=cell,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"]
orb = ["orb-models>=0.5.2"]
sevenn = ["sevenn[torchsim]>=0.12.1"]
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
nequip = ["nequip>=0.16.2"]
nequip = ["nequip>=0.17.0"]
fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"]
docs = [
"autodoc_pydantic==2.2.0",
Expand Down Expand Up @@ -132,7 +132,7 @@ docstring-code-format = true

[tool.codespell]
check-filenames = true
ignore-words-list = ["convertor"]
ignore-words-list = ["convertor"] # codespell:ignore convertor

[tool.pytest.ini_options]
addopts = ["-p no:warnings"]
Expand Down
5 changes: 1 addition & 4 deletions tests/test_fix_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from torch_sim.models.interface import ModelInterface
from torch_sim.models.lennard_jones import UnbatchedLennardJonesModel
from torch_sim.symmetrize import get_symmetry_datasets
from torch_sim.typing import StateDict


pytest.importorskip("moyopy")
Expand Down Expand Up @@ -100,9 +99,7 @@ def __init__(
self._compute_stress = model.compute_stress
self._compute_forces = model.compute_forces

def forward(
self, state: ts.SimState | StateDict, **kwargs: object
) -> dict[str, torch.Tensor]:
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
"""Forward pass with added noise."""
results = self.model(state, **kwargs)
for key in ("forces", "stress"):
Expand Down
17 changes: 4 additions & 13 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
from collections.abc import Callable
from dataclasses import fields
from functools import partial
from typing import Any, get_args

Expand Down Expand Up @@ -572,13 +571,9 @@ def test_simple_optimizer_init_with_dict(
ar_supercell_sim_state: SimState,
lj_model: ModelInterface,
) -> None:
"""Test simple optimizer init_fn with a SimState dictionary."""
state_dict = {
field.name: getattr(ar_supercell_sim_state, field.name)
for field in fields(ar_supercell_sim_state)
}
"""Test simple optimizer init_fn with a SimState."""
init_fn, _ = ts.OPTIM_REGISTRY[optimizer_fn]
opt_state = init_fn(model=lj_model, state=state_dict)
opt_state = init_fn(model=lj_model, state=ar_supercell_sim_state)
assert isinstance(opt_state, expected_state_type)
assert opt_state.energy is not None
assert opt_state.forces is not None
Expand Down Expand Up @@ -822,15 +817,11 @@ def test_cell_optimizer_init_with_dict_and_cell_factor(
ar_supercell_sim_state: SimState,
lj_model: ModelInterface,
) -> None:
"""Test cell optimizer init_fn with dict state and explicit cell_factor."""
state_dict = {
f.name: getattr(ar_supercell_sim_state, f.name)
for f in fields(ar_supercell_sim_state)
}
"""Test cell optimizer init_fn with explicit cell_factor."""
init_fn, _ = ts.OPTIM_REGISTRY[optimizer_fn]
opt_state = init_fn(
model=lj_model,
state=state_dict,
state=ar_supercell_sim_state,
cell_factor=cell_factor_val,
cell_filter=cell_filter,
)
Expand Down
6 changes: 3 additions & 3 deletions torch_sim/autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]:
def calculate_memory_scalers(
state: SimState,
memory_scales_with: MemoryScaling = "n_atoms_x_density",
cutoff: float = 7.0,
cutoff: float = 6.0,
) -> list[float]:
"""Calculate a metric that estimates memory requirements for each system in a state.

Expand Down Expand Up @@ -564,7 +564,7 @@ def __init__(
model: ModelInterface,
*,
memory_scales_with: MemoryScaling = "n_atoms_x_density",
cutoff: float = 7.0,
cutoff: float = 6.0,
max_memory_scaler: float | None = None,
max_atoms_to_try: int = 500_000,
memory_scaling_factor: float = 1.6,
Expand Down Expand Up @@ -868,7 +868,7 @@ def __init__(
model: ModelInterface,
*,
memory_scales_with: MemoryScaling = "n_atoms_x_density",
cutoff: float = 7.0,
cutoff: float = 6.0,
max_memory_scaler: float | None = None,
max_atoms_to_try: int = 500_000,
memory_scaling_factor: float = 1.6,
Expand Down
20 changes: 7 additions & 13 deletions torch_sim/integrators/npt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
)
from torch_sim.integrators.nvt import _vrescale_update
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState, ensure_sim_state
from torch_sim.typing import StateDict
from torch_sim.state import SimState


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -506,7 +505,7 @@ def _compute_cell_force(


def npt_langevin_init(
state: SimState | StateDict,
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
Expand All @@ -528,8 +527,7 @@ def npt_langevin_init(
Args:
model (ModelInterface): Neural network model that computes energies, forces,
and stress. Must return a dict with 'energy', 'forces', and 'stress' keys.
state (MDState | StateDict): Either a MDState object or a dictionary
containing positions, masses, cell, pbc
state (SimState): SimState containing positions, masses, cell, pbc
kT (torch.Tensor): Target temperature in energy units, either scalar or
with shape [n_systems]
dt (torch.Tensor): Integration timestep, either scalar or shape [n_systems]
Expand Down Expand Up @@ -565,8 +563,6 @@ def npt_langevin_init(
kT = torch.as_tensor(kT, device=device, dtype=dtype)
dt = torch.as_tensor(dt, device=device, dtype=dtype)

state = ensure_sim_state(state)

if alpha.ndim == 0:
alpha = alpha.expand(state.n_systems)
if cell_alpha.ndim == 0:
Expand Down Expand Up @@ -1284,7 +1280,7 @@ def _npt_nose_hoover_inner_step(


def npt_nose_hoover_init(
state: SimState | StateDict,
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
Expand All @@ -1307,7 +1303,7 @@ def npt_nose_hoover_init(

Args:
model (ModelInterface): Model to compute forces and energies
state: Initial system state as MDState or dict containing positions, masses,
state: Initial system state as SimState containing positions, masses,
cell, and PBC information
kT: Target temperature in energy units
external_pressure: Target external pressure
Expand Down Expand Up @@ -1339,7 +1335,6 @@ def npt_nose_hoover_init(
- Cell dynamics use logarithmic coordinates for volume updates
- All cell properties are properly initialized with batch dimensions
"""
state = ensure_sim_state(state)
device, dtype = state.device, state.dtype
dt_tensor = torch.as_tensor(dt, device=device, dtype=dtype)
kT_tensor = torch.as_tensor(kT, device=device, dtype=dtype)
Expand Down Expand Up @@ -2305,7 +2300,7 @@ def npt_crescale_isotropic_step(


def npt_crescale_init(
state: SimState | StateDict,
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
Expand All @@ -2325,7 +2320,7 @@ def npt_crescale_init(
To seed the RNG set ``state.rng = seed`` before calling.

Args:
state: Initial system state as MDState or dict containing positions, masses,
state: Initial system state as SimState containing positions, masses,
cell, and PBC information
model (ModelInterface): Model to compute forces and energies
kT: Target temperature in energy units
Expand All @@ -2334,7 +2329,6 @@ def npt_crescale_init(
isothermal_compressibility: Isothermal compressibility of the system.
"""
device, dtype = model.device, model.dtype
state = ensure_sim_state(state)

# Convert all parameters to tensors with correct device and dtype
dt = torch.as_tensor(dt, device=device, dtype=dtype)
Expand Down
11 changes: 4 additions & 7 deletions torch_sim/integrators/nve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
position_step,
)
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState, ensure_sim_state
from torch_sim.typing import StateDict
from torch_sim.state import SimState


def nve_init(
state: SimState | StateDict,
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
Expand All @@ -33,8 +32,8 @@ def nve_init(
Args:
model: Neural network model that computes energies and forces.
Must return a dict with 'energy' and 'forces' keys.
state: Either a SimState object or a dictionary containing positions,
masses, cell, pbc, and other required state variables
state: SimState containing positions, masses, cell, pbc, and other
required state variables
kT: Temperature in energy units for initializing momenta,
scalar or with shape [n_systems]

Expand All @@ -46,8 +45,6 @@ def nve_init(
- Initial velocities sampled from Maxwell-Boltzmann distribution
- Time integration error scales as O(dt²)
"""
state = ensure_sim_state(state)

model_output = model(state)

momenta = getattr(
Expand Down
24 changes: 9 additions & 15 deletions torch_sim/integrators/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
velocity_verlet_step,
)
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState, ensure_sim_state
from torch_sim.typing import StateDict
from torch_sim.state import SimState


def _ou_step(
Expand Down Expand Up @@ -85,7 +84,7 @@ def _ou_step(


def nvt_langevin_init(
state: SimState | StateDict,
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
Expand All @@ -102,8 +101,8 @@ def nvt_langevin_init(
Args:
model: Neural network model that computes energies and forces.
Must return a dict with 'energy' and 'forces' keys.
state: Either a SimState object or a dictionary containing positions,
masses, cell, pbc, and other required state vars
state: SimState containing positions, masses, cell, pbc, and other
required state vars
kT: Temperature in energy units for initializing momenta,
either scalar or with shape [n_systems]

Expand All @@ -116,8 +115,6 @@ def nvt_langevin_init(
at the specified temperature. This provides a proper thermal initial
state for the subsequent Langevin dynamics.
"""
state = ensure_sim_state(state)

model_output = model(state)

momenta = getattr(
Expand Down Expand Up @@ -251,7 +248,7 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor:


def nvt_nose_hoover_init(
state: SimState | StateDict,
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
Expand All @@ -272,7 +269,7 @@ def nvt_nose_hoover_init(
To seed the RNG set ``state.rng = seed`` before calling.

Args:
state: Initial system state as SimState or dict
state: Initial system state as SimState
model: Neural network model that computes energies and forces
kT: Target temperature in energy units
dt: Integration timestep
Expand All @@ -292,7 +289,6 @@ def nvt_nose_hoover_init(
- Chain variables evolve to maintain target temperature
- Time-reversible when integrated with appropriate algorithms
"""
state = ensure_sim_state(state)
dt_tensor = torch.as_tensor(dt, device=state.device, dtype=state.dtype)
kT_tensor = torch.as_tensor(kT, device=state.device, dtype=state.dtype)
tau_tensor = torch.as_tensor(
Expand Down Expand Up @@ -570,7 +566,7 @@ def _vrescale_update[T: MDState](


def nvt_vrescale_init(
state: SimState | StateDict,
state: SimState,
model: ModelInterface,
*,
kT: float | torch.Tensor,
Expand All @@ -588,8 +584,8 @@ def nvt_vrescale_init(
Args:
model: Neural network model that computes energies and forces.
Must return a dict with 'energy' and 'forces' keys.
state: Either a SimState object or a dictionary containing positions,
masses, cell, pbc, and other required state vars
state: SimState containing positions, masses, cell, pbc, and other
required state vars
kT: Temperature in energy units for initializing momenta,
either scalar or with shape [n_systems]

Expand All @@ -602,8 +598,6 @@ def nvt_vrescale_init(
at the specified temperature. The V-Rescale thermostat provides proper
canonical sampling through stochastic velocity rescaling.
"""
state = ensure_sim_state(state)

model_output = model(state)

momenta = getattr(
Expand Down
14 changes: 5 additions & 9 deletions torch_sim/models/fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torch

from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState, ensure_sim_state


try:
Expand All @@ -45,7 +44,7 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
if typing.TYPE_CHECKING:
from collections.abc import Callable

from torch_sim.typing import StateDict
from torch_sim.state import SimState


class FairChemModel(ModelInterface):
Expand Down Expand Up @@ -169,15 +168,12 @@ def device(self) -> torch.device:
"""Return the device where the model is located."""
return self._device

def forward(
self, state: SimState | StateDict, **_kwargs: object
) -> dict[str, torch.Tensor]:
def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
"""Compute energies, forces, and other properties.

Args:
state (SimState | StateDict): State object containing positions, cells,
atomic numbers, and other system information. If a dictionary is provided,
it will be converted to a SimState.
state (SimState): State object containing positions, cells, atomic numbers,
and other system information.
**_kwargs: Unused; accepted for interface compatibility.

Returns:
Expand All @@ -186,7 +182,7 @@ def forward(
- forces (torch.Tensor): Forces with shape [n_atoms, 3]
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3]
"""
sim_state = ensure_sim_state(state)
sim_state = state

if sim_state.device != self._device:
sim_state = sim_state.to(self._device)
Expand Down
Loading
Loading