From c819ae476d7313b08f3ad2b5c366936518793a35 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 23 Jan 2026 11:01:42 -0500 Subject: [PATCH 1/9] replace manual initialization with from_state across integrators and optimizers --- torch_sim/integrators/npt.py | 28 +++++----------------- torch_sim/integrators/nve.py | 10 ++------ torch_sim/integrators/nvt.py | 30 ++++++------------------ torch_sim/optimizers/fire.py | 24 +++++-------------- torch_sim/optimizers/gradient_descent.py | 21 +++++------------ 5 files changed, 27 insertions(+), 86 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index f9110ccf5..4d6bd752e 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -612,17 +612,12 @@ def npt_langevin_init( ) # Create the initial state - return NPTLangevinState( - positions=state.positions, + return NPTLangevinState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], stress=model_output["stress"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, alpha=alpha, b_tau=b_tau, reference_cell=reference_cell, @@ -630,7 +625,6 @@ def npt_langevin_init( cell_velocities=cell_velocities, cell_masses=cell_masses, cell_alpha=cell_alpha, - _constraints=state.constraints, ) @@ -1421,16 +1415,12 @@ def npt_nose_hoover_init( ) # Create initial state - return NPTNoseHooverState( - positions=state.positions, + return NPTNoseHooverState.from_state( + state, momenta=momenta, energy=energy, forces=forces, - masses=state.masses, atomic_numbers=atomic_numbers, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, reference_cell=reference_cell, cell_position=cell_position, cell_momentum=cell_momentum, @@ -1439,7 +1429,6 @@ def npt_nose_hoover_init( thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT), barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, - _constraints=state.constraints, ) @@ -2315,17 +2304,12 @@ def npt_crescale_init( ) # Create the initial state - return NPTCRescaleState( - positions=state.positions, + return NPTCRescaleState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], stress=model_output["stress"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, ) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index b4db4e6c9..4880cfac6 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -57,17 +57,11 @@ def nve_init( calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( - positions=state.positions, + return MDState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, - _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index d773a9221..61cb84723 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -118,17 +118,11 @@ def nvt_langevin_init( "momenta", calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( - positions=state.positions, + return MDState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, - _constraints=state.constraints, ) @@ -316,19 +310,14 @@ def nvt_nose_hoover_init( ) # n_atoms * n_dimensions # Initialize state - return NVTNoseHooverState( - positions=state.positions, + return NVTNoseHooverState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, atomic_numbers=atomic_numbers, - system_idx=state.system_idx, chain=chain_fns.initialize(dof_per_system, KE, kT), - _chain_fns=chain_fns, # Store the chain functions - _constraints=state.constraints, + _chain_fns=chain_fns, ) @@ -603,16 +592,11 @@ def nvt_vrescale_init( calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return NVTVRescaleState( - positions=state.positions, + return NVTVRescaleState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, ) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index e69a4955f..763cb646c 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -72,24 +72,12 @@ def fire_init( forces = model_output["forces"] stress = model_output.get("stress") - # Common state arguments - common_args = { - # Copy SimState attributes - "positions": state.positions.clone(), - "masses": state.masses.clone(), - "cell": state.cell.clone(), - "atomic_numbers": state.atomic_numbers.clone(), - "system_idx": state.system_idx.clone(), - "_constraints": state.constraints, - "pbc": state.pbc, - "charge": state.charge.clone(), - "spin": state.spin.clone(), - # Optimization state + # FIRE-specific additional attributes + fire_attrs = { "forces": forces, "energy": energy, "stress": stress, "velocities": torch.full(state.positions.shape, torch.nan, **tensor_args), - # FIRE parameters "dt": torch.full((n_systems,), dt_start, **tensor_args), "alpha": torch.full((n_systems,), alpha_start, **tensor_args), "n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), @@ -97,9 +85,9 @@ def fire_init( if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - common_args["reference_cell"] = state.cell.clone() - common_args["cell_filter"] = cell_filter_funcs - cell_state = CellFireState(**common_args) + fire_attrs["reference_cell"] = state.cell.clone() + fire_attrs["cell_filter"] = cell_filter_funcs + cell_state = CellFireState.from_state(state, **fire_attrs) # Initialize cell-specific attributes init_fn(cell_state, model, **filter_kwargs) @@ -111,7 +99,7 @@ def fire_init( return cell_state # Create regular FireState without cell optimization - return FireState(**common_args) + return FireState.from_state(state, **fire_attrs) def fire_step( diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 283fe7c80..23a51a0ed 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -50,34 +50,25 @@ def gradient_descent_init( forces = model_output["forces"] stress = model_output.get("stress") - # Common state arguments - common_args = { - "positions": state.positions, + # Optimizer-specific additional attributes + optim_attrs = { "forces": forces, "energy": energy, "stress": stress, - "masses": state.masses, - "cell": state.cell, - "pbc": state.pbc, - "atomic_numbers": state.atomic_numbers, - "system_idx": state.system_idx, - "_constraints": state.constraints, - "charge": state.charge, - "spin": state.spin, } if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - common_args["reference_cell"] = state.cell.clone() - common_args["cell_filter"] = cell_filter_funcs - cell_state = CellOptimState(**common_args) + optim_attrs["reference_cell"] = state.cell.clone() + optim_attrs["cell_filter"] = cell_filter_funcs + cell_state = CellOptimState.from_state(state, **optim_attrs) # Initialize cell-specific attributes init_fn(cell_state, model, **filter_kwargs) return cell_state # Create regular OptimState without cell optimization - return OptimState(**common_args) + return OptimState.from_state(state, **optim_attrs) def gradient_descent_step( From c782610b1d4f0ed2c8d49b63fce50018cb7c7e35 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 23 Jan 2026 11:34:08 -0500 Subject: [PATCH 2/9] missed actually relevant state change --- torch_sim/state.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index e97101d73..890cf51e0 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -369,9 +369,10 @@ def clone(self) -> Self: def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: """Create a new state from an existing state with additional attributes. - This method copies all attributes from the source state and adds any additional - attributes needed for the target state class. It's useful for converting between - different state types (e.g., SimState to MDState). + This method copies attributes from the source state that are valid for the + target state class, and adds any additional attributes needed. It supports + upcasting (SimState -> MDState), downcasting (MDState -> SimState), and + cross-casting (MDState -> OptimState) between state types. Args: state: Source state to copy base attributes from @@ -389,13 +390,22 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: ... momenta=torch.zeros_like(sim_state.positions), ... ) """ - # Copy all attributes from the source state + # Get attributes that the TARGET class accepts + target_attrs = ( + cls._atom_attributes + | cls._system_attributes + | cls._global_attributes + | {"_constraints"} + ) + + # Copy only attributes that exist on BOTH source AND target attrs = {} for attr_name, attr_value in state.attributes.items(): - if isinstance(attr_value, torch.Tensor): - attrs[attr_name] = attr_value.clone() - else: - attrs[attr_name] = copy.deepcopy(attr_value) + if attr_name in target_attrs: + if isinstance(attr_value, torch.Tensor): + attrs[attr_name] = attr_value.clone() + else: + attrs[attr_name] = copy.deepcopy(attr_value) # Add/override with additional attributes attrs.update(additional_attrs) From 5b8d5bcff54100be96337c5c53c4720e4c2a819e Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 23 Jan 2026 12:24:14 -0500 Subject: [PATCH 3/9] make sure charge and spin are correctly passed in fairchem model, they were missing! --- torch_sim/models/fairchem.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 83d783e68..602cf7529 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -222,10 +222,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict: atoms.info["spin"] = sim_state.spin[idx].item() # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) + # r_data_keys must be passed for charge/spin to be read from atoms.info if self.task_name is None: - atomic_data = AtomicData.from_ase(atoms) + atomic_data = AtomicData.from_ase(atoms, r_data_keys=["charge", "spin"]) else: - atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name) + atomic_data = AtomicData.from_ase( + atoms, task_name=self.task_name, r_data_keys=["charge", "spin"] + ) atomic_data_list.append(atomic_data) # Create batch for efficient inference From 2735e961731d6beb1227ebdcd5c6558a46b6bbea Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 23 Jan 2026 12:24:30 -0500 Subject: [PATCH 4/9] pass charge and spin in to_atoms --- torch_sim/io.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch_sim/io.py b/torch_sim/io.py index 27be6b1c4..e2fae5db9 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -41,6 +41,7 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: Notes: - Output positions and cell will be in Å - Output masses will be in amu + - Charge and spin are preserved in atoms.info if present in the state """ try: from ase import Atoms @@ -55,6 +56,10 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: system_indices = state.system_idx.detach().cpu().numpy() pbc = state.pbc.detach().cpu().numpy() + # Extract charge and spin if available (per-system attributes) + charge = state.charge.detach().cpu().numpy() if state.charge is not None else None + spin = state.spin.detach().cpu().numpy() if state.spin is not None else None + atoms_list = [] for sys_idx in np.unique(system_indices): mask = system_indices == sys_idx @@ -68,6 +73,13 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: atoms = Atoms( symbols=symbols, positions=system_positions, cell=system_cell, pbc=pbc ) + + # Preserve charge and spin in atoms.info (as integers for FairChem compatibility) + if charge is not None: + atoms.info["charge"] = int(charge[sys_idx].item()) + if spin is not None: + atoms.info["spin"] = int(spin[sys_idx].item()) + atoms_list.append(atoms) return atoms_list From 3a552efd1b615c78cc9f51e44e9899d9f23dddd7 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 27 Jan 2026 10:01:53 -0500 Subject: [PATCH 5/9] add from_state to StaticState and add _all_attributes --- torch_sim/runners.py | 9 ++------- torch_sim/state.py | 22 ++++++++++++---------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 21d10be8e..ff445bf9e 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -762,13 +762,8 @@ class StaticState(SimState): ) model_outputs = model(sub_state) - static_state = StaticState( - positions=sub_state.positions, - masses=sub_state.masses, - cell=sub_state.cell, - pbc=sub_state.pbc, - atomic_numbers=sub_state.atomic_numbers, - system_idx=sub_state.system_idx, + static_state = StaticState.from_state( + state=sub_state, energy=model_outputs["energy"], forces=( model_outputs["forces"] diff --git a/torch_sim/state.py b/torch_sim/state.py index 890cf51e0..38621933a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -183,6 +183,17 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") + @classmethod + @property + def _all_attributes(cls) -> set[str]: + """Get all attributes of the SimState.""" + return ( + cls._atom_attributes + | cls._system_attributes + | cls._global_attributes + | {"_constraints"} + ) + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, @@ -390,18 +401,9 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: ... momenta=torch.zeros_like(sim_state.positions), ... ) """ - # Get attributes that the TARGET class accepts - target_attrs = ( - cls._atom_attributes - | cls._system_attributes - | cls._global_attributes - | {"_constraints"} - ) - - # Copy only attributes that exist on BOTH source AND target attrs = {} for attr_name, attr_value in state.attributes.items(): - if attr_name in target_attrs: + if attr_name in cls._all_attributes: if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.clone() else: From 84a927f6786a3873fefe4377724c45eb995319ba Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 27 Jan 2026 10:09:27 -0500 Subject: [PATCH 6/9] clean up spin charge def --- torch_sim/io.py | 4 ++-- torch_sim/state.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torch_sim/io.py b/torch_sim/io.py index e2fae5db9..fdd75893c 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -57,8 +57,8 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: pbc = state.pbc.detach().cpu().numpy() # Extract charge and spin if available (per-system attributes) - charge = state.charge.detach().cpu().numpy() if state.charge is not None else None - spin = state.spin.detach().cpu().numpy() if state.spin is not None else None + charge = state.charge.detach().cpu().numpy() + spin = state.spin.detach().cpu().numpy() atoms_list = [] for sys_idx in np.unique(system_indices): diff --git a/torch_sim/state.py b/torch_sim/state.py index 38621933a..600e22e35 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -96,14 +96,13 @@ class SimState: if TYPE_CHECKING: @property - def system_idx(self) -> torch.Tensor: - """A getter for system_idx that tells type checkers it's always defined.""" - return self.system_idx - + def system_idx(self) -> torch.Tensor: ... # noqa: D102 + @property + def pbc(self) -> torch.Tensor: ... # noqa: D102 + @property + def charge(self) -> torch.Tensor: ... # noqa: D102 @property - def pbc(self) -> torch.Tensor: - """A getter for pbc that tells type checkers it's always defined.""" - return self.pbc + def spin(self) -> torch.Tensor: ... # noqa: D102 _atom_attributes: ClassVar[set[str]] = { "positions", From c77b2fcebafeb337b9d7039d36ed0cb09934a413 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 27 Jan 2026 10:14:13 -0500 Subject: [PATCH 7/9] try swapping decorators? --- torch_sim/state.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 600e22e35..880156eda 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -182,8 +182,8 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") - @classmethod @property + @classmethod def _all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" return ( @@ -234,13 +234,7 @@ def volume(self) -> torch.Tensor: @property def attributes(self) -> dict[str, torch.Tensor]: """Get all public attributes of the state.""" - return { - attr: getattr(self, attr) - for attr in self._atom_attributes - | self._system_attributes - | self._global_attributes - | {"_constraints"} - } + return {attr: getattr(self, attr) for attr in self._all_attributes} @property def column_vector_cell(self) -> torch.Tensor: From 0dbf78ca4a86eb383e8177e2029afd09bda76662 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 27 Jan 2026 10:28:27 -0500 Subject: [PATCH 8/9] switch back --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 880156eda..18c7020bc 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -182,8 +182,8 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") - @property @classmethod + @property def _all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" return ( From a1da3818900dbfd0fb219d8a3e67ce939a2ec248 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 27 Jan 2026 10:34:25 -0500 Subject: [PATCH 9/9] change name and make method --- torch_sim/state.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 18c7020bc..ec2c589de 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -183,8 +183,7 @@ def __post_init__(self) -> None: # noqa: C901 raise ValueError("All tensors must be on the same device") @classmethod - @property - def _all_attributes(cls) -> set[str]: + def _get_all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" return ( cls._atom_attributes @@ -234,7 +233,7 @@ def volume(self) -> torch.Tensor: @property def attributes(self) -> dict[str, torch.Tensor]: """Get all public attributes of the state.""" - return {attr: getattr(self, attr) for attr in self._all_attributes} + return {attr: getattr(self, attr) for attr in self._get_all_attributes()} @property def column_vector_cell(self) -> torch.Tensor: @@ -396,7 +395,7 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: """ attrs = {} for attr_name, attr_value in state.attributes.items(): - if attr_name in cls._all_attributes: + if attr_name in cls._get_all_attributes(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.clone() else: