From 74a22a227dc7016374f7559765e01916f9acfef3 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 13:15:18 -0700 Subject: [PATCH 1/3] docs --- torch_sim/state.py | 55 +++++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index ce21ef9bd..d9f98ef25 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -7,7 +7,7 @@ import copy import importlib import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, Self import torch @@ -22,7 +22,7 @@ from pymatgen.core import Structure -@dataclass +@dataclass(init=False) class SimState: """State representation for atomistic systems with batched operations support. @@ -47,9 +47,8 @@ class SimState: used by ASE. pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) - system_idx (torch.Tensor, optional): Maps each atom index to its system index. - Has shape (n_atoms,), defaults to None, must be unique consecutive - integers starting from 0 + system_idx (torch.Tensor): Maps each atom index to its system index. + Has shape (n_atoms,), must be unique consecutive integers starting from 0. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary @@ -81,10 +80,35 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - system_idx: torch.Tensor | None = field(default=None, kw_only=True) + system_idx: torch.Tensor + + def __init__( + self, + positions: torch.Tensor, + masses: torch.Tensor, + cell: torch.Tensor, + pbc: bool, # noqa: FBT001 # TODO(curtis): maybe make the constructor be keyword-only (it can be easy to confuse positions vs masses, etc.) + atomic_numbers: torch.Tensor, + system_idx: torch.Tensor | None = None, + ) -> None: + """Initialize the SimState and validate the arguments. + + Args: + positions (torch.Tensor): Atomic positions with shape (n_atoms, 3) + masses (torch.Tensor): Atomic masses with shape (n_atoms,) + cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3). + pbc (bool): Boolean indicating whether to use periodic boundary conditions + atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) + system_idx (torch.Tensor | None): Maps each atom index to its system index. + Has shape (n_atoms,), must be unique consecutive integers starting from 0. + If not provided, it is initialized to zeros. + """ + self.positions = positions + self.masses = masses + self.cell = cell + self.pbc = pbc + self.atomic_numbers = atomic_numbers - def __post_init__(self) -> None: - """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way @@ -107,17 +131,12 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if self.cell.ndim != 3 and self.system_idx is None: - self.cell = self.cell.unsqueeze(0) - - if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_systems, 3, 3)") - - if self.system_idx is None: + if system_idx is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: + self.system_idx = system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. @@ -125,6 +144,12 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(self.system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.cell.ndim != 3 and self.system_idx is None: + self.cell = self.cell.unsqueeze(0) + + if self.cell.shape[-2:] != (3, 3): + raise ValueError("Cell must have shape (n_systems, 3, 3)") + if self.cell.shape[0] != self.n_systems: raise ValueError( f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" From 812da25951e0cb5e25da79c503d9616a0b14f26a Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 13:29:06 -0700 Subject: [PATCH 2/3] fix behaviour --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index d9f98ef25..5240d0946 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -144,7 +144,7 @@ def __init__( if not torch.all(counts == torch.bincount(self.system_idx)): raise ValueError("System indices must be unique consecutive integers") - if self.cell.ndim != 3 and self.system_idx is None: + if self.cell.ndim != 3 and system_idx is None: self.cell = self.cell.unsqueeze(0) if self.cell.shape[-2:] != (3, 3): From fe07159d395d74b75b9591a587daacc80929f910 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 16:05:15 -0700 Subject: [PATCH 3/3] add system idx as param --- torch_sim/integrators/nvt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index ff9b7b4bc..19a811c8b 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -389,6 +389,7 @@ def nvt_nose_hoover_init( cell=state.cell, pbc=state.pbc, atomic_numbers=atomic_numbers, + system_idx=state.system_idx, chain=chain_fns.initialize(total_dof, KE, kT), _chain_fns=chain_fns, # Store the chain functions )