Skip to content
28 changes: 6 additions & 22 deletions torch_sim/integrators/npt.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,25 +612,19 @@ def npt_langevin_init(
)

# Create the initial state
return NPTLangevinState(
positions=state.positions,
return NPTLangevinState.from_state(
state,
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
stress=model_output["stress"],
masses=state.masses,
cell=state.cell,
pbc=state.pbc,
system_idx=state.system_idx,
atomic_numbers=state.atomic_numbers,
alpha=alpha,
b_tau=b_tau,
reference_cell=reference_cell,
cell_positions=cell_positions,
cell_velocities=cell_velocities,
cell_masses=cell_masses,
cell_alpha=cell_alpha,
_constraints=state.constraints,
)


Expand Down Expand Up @@ -1421,16 +1415,12 @@ def npt_nose_hoover_init(
)

# Create initial state
return NPTNoseHooverState(
positions=state.positions,
return NPTNoseHooverState.from_state(
state,
momenta=momenta,
energy=energy,
forces=forces,
masses=state.masses,
atomic_numbers=atomic_numbers,
cell=state.cell,
pbc=state.pbc,
system_idx=state.system_idx,
reference_cell=reference_cell,
cell_position=cell_position,
cell_momentum=cell_momentum,
Expand All @@ -1439,7 +1429,6 @@ def npt_nose_hoover_init(
thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT),
barostat_fns=barostat_fns,
thermostat_fns=thermostat_fns,
_constraints=state.constraints,
)


Expand Down Expand Up @@ -2315,17 +2304,12 @@ def npt_crescale_init(
)

# Create the initial state
return NPTCRescaleState(
positions=state.positions,
return NPTCRescaleState.from_state(
state,
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
stress=model_output["stress"],
masses=state.masses,
cell=state.cell,
pbc=state.pbc,
system_idx=state.system_idx,
atomic_numbers=state.atomic_numbers,
tau_p=tau_p,
isothermal_compressibility=isothermal_compressibility,
)
10 changes: 2 additions & 8 deletions torch_sim/integrators/nve.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,11 @@ def nve_init(
calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed),
)

return MDState(
positions=state.positions,
return MDState.from_state(
state,
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
masses=state.masses,
cell=state.cell,
pbc=state.pbc,
system_idx=state.system_idx,
atomic_numbers=state.atomic_numbers,
_constraints=state.constraints,
)


Expand Down
30 changes: 7 additions & 23 deletions torch_sim/integrators/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,11 @@ def nvt_langevin_init(
"momenta",
calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed),
)
return MDState(
positions=state.positions,
return MDState.from_state(
state,
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
masses=state.masses,
cell=state.cell,
pbc=state.pbc,
system_idx=state.system_idx,
atomic_numbers=state.atomic_numbers,
_constraints=state.constraints,
)


Expand Down Expand Up @@ -316,19 +310,14 @@ def nvt_nose_hoover_init(
) # n_atoms * n_dimensions

# Initialize state
return NVTNoseHooverState(
positions=state.positions,
return NVTNoseHooverState.from_state(
state,
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
masses=state.masses,
cell=state.cell,
pbc=state.pbc,
atomic_numbers=atomic_numbers,
system_idx=state.system_idx,
chain=chain_fns.initialize(dof_per_system, KE, kT),
_chain_fns=chain_fns, # Store the chain functions
_constraints=state.constraints,
_chain_fns=chain_fns,
)


Expand Down Expand Up @@ -603,16 +592,11 @@ def nvt_vrescale_init(
calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed),
)

return NVTVRescaleState(
positions=state.positions,
return NVTVRescaleState.from_state(
state,
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
masses=state.masses,
cell=state.cell,
pbc=state.pbc,
system_idx=state.system_idx,
atomic_numbers=state.atomic_numbers,
)


Expand Down
12 changes: 12 additions & 0 deletions torch_sim/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]:
Notes:
- Output positions and cell will be in Å
- Output masses will be in amu
- Charge and spin are preserved in atoms.info if present in the state
"""
try:
from ase import Atoms
Expand All @@ -55,6 +56,10 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]:
system_indices = state.system_idx.detach().cpu().numpy()
pbc = state.pbc.detach().cpu().numpy()

# Extract charge and spin if available (per-system attributes)
charge = state.charge.detach().cpu().numpy()
spin = state.spin.detach().cpu().numpy()

atoms_list = []
for sys_idx in np.unique(system_indices):
mask = system_indices == sys_idx
Expand All @@ -68,6 +73,13 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]:
atoms = Atoms(
symbols=symbols, positions=system_positions, cell=system_cell, pbc=pbc
)

# Preserve charge and spin in atoms.info (as integers for FairChem compatibility)
if charge is not None:
atoms.info["charge"] = int(charge[sys_idx].item())
Copy link
Copy Markdown
Collaborator

@curtischong curtischong Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should change int to float since we use float in other places in the code when talking about charges/spins

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should probably change to int it in the other places then. charge and spin are physically integer values.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should too. it's a bit annoying how charges is a list of float in ASE https://ase-lib.org/ase/atoms.html

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spin are physically integer values.

https://en.wikipedia.org/wiki/Fractionalization .... but here yes

if spin is not None:
atoms.info["spin"] = int(spin[sys_idx].item())

atoms_list.append(atoms)

return atoms_list
Expand Down
7 changes: 5 additions & 2 deletions torch_sim/models/fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict:
atoms.info["spin"] = sim_state.spin[idx].item()

# Convert ASE Atoms to AtomicData (task_name only applies to UMA models)
# r_data_keys must be passed for charge/spin to be read from atoms.info
if self.task_name is None:
atomic_data = AtomicData.from_ase(atoms)
atomic_data = AtomicData.from_ase(atoms, r_data_keys=["charge", "spin"])
else:
atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name)
atomic_data = AtomicData.from_ase(
atoms, task_name=self.task_name, r_data_keys=["charge", "spin"]
)
atomic_data_list.append(atomic_data)

# Create batch for efficient inference
Expand Down
24 changes: 6 additions & 18 deletions torch_sim/optimizers/fire.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,34 +72,22 @@ def fire_init(
forces = model_output["forces"]
stress = model_output.get("stress")

# Common state arguments
common_args = {
# Copy SimState attributes
"positions": state.positions.clone(),
"masses": state.masses.clone(),
"cell": state.cell.clone(),
"atomic_numbers": state.atomic_numbers.clone(),
"system_idx": state.system_idx.clone(),
"_constraints": state.constraints,
"pbc": state.pbc,
"charge": state.charge.clone(),
"spin": state.spin.clone(),
# Optimization state
# FIRE-specific additional attributes
fire_attrs = {
"forces": forces,
"energy": energy,
"stress": stress,
"velocities": torch.full(state.positions.shape, torch.nan, **tensor_args),
# FIRE parameters
"dt": torch.full((n_systems,), dt_start, **tensor_args),
"alpha": torch.full((n_systems,), alpha_start, **tensor_args),
"n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32),
}

if cell_filter is not None: # Create cell optimization state
cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter)
common_args["reference_cell"] = state.cell.clone()
common_args["cell_filter"] = cell_filter_funcs
cell_state = CellFireState(**common_args)
fire_attrs["reference_cell"] = state.cell.clone()
fire_attrs["cell_filter"] = cell_filter_funcs
cell_state = CellFireState.from_state(state, **fire_attrs)

# Initialize cell-specific attributes
init_fn(cell_state, model, **filter_kwargs)
Expand All @@ -111,7 +99,7 @@ def fire_init(

return cell_state
# Create regular FireState without cell optimization
return FireState(**common_args)
return FireState.from_state(state, **fire_attrs)


def fire_step(
Expand Down
21 changes: 6 additions & 15 deletions torch_sim/optimizers/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,25 @@ def gradient_descent_init(
forces = model_output["forces"]
stress = model_output.get("stress")

# Common state arguments
common_args = {
"positions": state.positions,
# Optimizer-specific additional attributes
optim_attrs = {
"forces": forces,
"energy": energy,
"stress": stress,
"masses": state.masses,
"cell": state.cell,
"pbc": state.pbc,
"atomic_numbers": state.atomic_numbers,
"system_idx": state.system_idx,
"_constraints": state.constraints,
"charge": state.charge,
"spin": state.spin,
}

if cell_filter is not None: # Create cell optimization state
cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter)
common_args["reference_cell"] = state.cell.clone()
common_args["cell_filter"] = cell_filter_funcs
cell_state = CellOptimState(**common_args)
optim_attrs["reference_cell"] = state.cell.clone()
optim_attrs["cell_filter"] = cell_filter_funcs
cell_state = CellOptimState.from_state(state, **optim_attrs)

# Initialize cell-specific attributes
init_fn(cell_state, model, **filter_kwargs)

return cell_state
# Create regular OptimState without cell optimization
return OptimState(**common_args)
return OptimState.from_state(state, **optim_attrs)


def gradient_descent_step(
Expand Down
9 changes: 2 additions & 7 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,13 +762,8 @@ class StaticState(SimState):
)

model_outputs = model(sub_state)
static_state = StaticState(
positions=sub_state.positions,
masses=sub_state.masses,
cell=sub_state.cell,
pbc=sub_state.pbc,
atomic_numbers=sub_state.atomic_numbers,
system_idx=sub_state.system_idx,
static_state = StaticState.from_state(
state=sub_state,
energy=model_outputs["energy"],
forces=(
model_outputs["forces"]
Expand Down
48 changes: 26 additions & 22 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,13 @@ class SimState:
if TYPE_CHECKING:

@property
def system_idx(self) -> torch.Tensor:
"""A getter for system_idx that tells type checkers it's always defined."""
return self.system_idx

def system_idx(self) -> torch.Tensor: ... # noqa: D102
@property
def pbc(self) -> torch.Tensor: ... # noqa: D102
@property
def pbc(self) -> torch.Tensor:
"""A getter for pbc that tells type checkers it's always defined."""
return self.pbc
def charge(self) -> torch.Tensor: ... # noqa: D102
@property
def spin(self) -> torch.Tensor: ... # noqa: D102

_atom_attributes: ClassVar[set[str]] = {
"positions",
Expand Down Expand Up @@ -183,6 +182,16 @@ def __post_init__(self) -> None: # noqa: C901
if len(set(devices.values())) > 1:
raise ValueError("All tensors must be on the same device")

@classmethod
def _get_all_attributes(cls) -> set[str]:
"""Get all attributes of the SimState."""
return (
cls._atom_attributes
| cls._system_attributes
| cls._global_attributes
| {"_constraints"}
)

@property
def wrap_positions(self) -> torch.Tensor:
"""Atomic positions wrapped according to periodic boundary conditions if pbc=True,
Expand Down Expand Up @@ -224,13 +233,7 @@ def volume(self) -> torch.Tensor:
@property
def attributes(self) -> dict[str, torch.Tensor]:
"""Get all public attributes of the state."""
return {
attr: getattr(self, attr)
for attr in self._atom_attributes
| self._system_attributes
| self._global_attributes
| {"_constraints"}
}
return {attr: getattr(self, attr) for attr in self._get_all_attributes()}

@property
def column_vector_cell(self) -> torch.Tensor:
Expand Down Expand Up @@ -369,9 +372,10 @@ def clone(self) -> Self:
def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self:
"""Create a new state from an existing state with additional attributes.

This method copies all attributes from the source state and adds any additional
attributes needed for the target state class. It's useful for converting between
different state types (e.g., SimState to MDState).
This method copies attributes from the source state that are valid for the
target state class, and adds any additional attributes needed. It supports
upcasting (SimState -> MDState), downcasting (MDState -> SimState), and
cross-casting (MDState -> OptimState) between state types.

Args:
state: Source state to copy base attributes from
Expand All @@ -389,13 +393,13 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self:
... momenta=torch.zeros_like(sim_state.positions),
... )
"""
# Copy all attributes from the source state
attrs = {}
for attr_name, attr_value in state.attributes.items():
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.clone()
else:
attrs[attr_name] = copy.deepcopy(attr_value)
if attr_name in cls._get_all_attributes():
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.clone()
else:
attrs[attr_name] = copy.deepcopy(attr_value)

# Add/override with additional attributes
attrs.update(additional_attrs)
Expand Down