diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index f9110ccf5..4d6bd752e 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -612,17 +612,12 @@ def npt_langevin_init( ) # Create the initial state - return NPTLangevinState( - positions=state.positions, + return NPTLangevinState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], stress=model_output["stress"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, alpha=alpha, b_tau=b_tau, reference_cell=reference_cell, @@ -630,7 +625,6 @@ def npt_langevin_init( cell_velocities=cell_velocities, cell_masses=cell_masses, cell_alpha=cell_alpha, - _constraints=state.constraints, ) @@ -1421,16 +1415,12 @@ def npt_nose_hoover_init( ) # Create initial state - return NPTNoseHooverState( - positions=state.positions, + return NPTNoseHooverState.from_state( + state, momenta=momenta, energy=energy, forces=forces, - masses=state.masses, atomic_numbers=atomic_numbers, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, reference_cell=reference_cell, cell_position=cell_position, cell_momentum=cell_momentum, @@ -1439,7 +1429,6 @@ def npt_nose_hoover_init( thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT), barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, - _constraints=state.constraints, ) @@ -2315,17 +2304,12 @@ def npt_crescale_init( ) # Create the initial state - return NPTCRescaleState( - positions=state.positions, + return NPTCRescaleState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], stress=model_output["stress"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, ) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index b4db4e6c9..4880cfac6 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -57,17 +57,11 @@ def nve_init( calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( - positions=state.positions, + return MDState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, - _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index d773a9221..61cb84723 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -118,17 +118,11 @@ def nvt_langevin_init( "momenta", calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( - positions=state.positions, + return MDState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, - _constraints=state.constraints, ) @@ -316,19 +310,14 @@ def nvt_nose_hoover_init( ) # n_atoms * n_dimensions # Initialize state - return NVTNoseHooverState( - positions=state.positions, + return NVTNoseHooverState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, atomic_numbers=atomic_numbers, - system_idx=state.system_idx, chain=chain_fns.initialize(dof_per_system, KE, kT), - _chain_fns=chain_fns, # Store the chain functions - _constraints=state.constraints, + _chain_fns=chain_fns, ) @@ -603,16 +592,11 @@ def nvt_vrescale_init( calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return NVTVRescaleState( - positions=state.positions, + return NVTVRescaleState.from_state( + state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - system_idx=state.system_idx, - atomic_numbers=state.atomic_numbers, ) diff --git a/torch_sim/io.py b/torch_sim/io.py index 27be6b1c4..fdd75893c 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -41,6 +41,7 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: Notes: - Output positions and cell will be in Å - Output masses will be in amu + - Charge and spin are preserved in atoms.info if present in the state """ try: from ase import Atoms @@ -55,6 +56,10 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: system_indices = state.system_idx.detach().cpu().numpy() pbc = state.pbc.detach().cpu().numpy() + # Extract charge and spin if available (per-system attributes) + charge = state.charge.detach().cpu().numpy() + spin = state.spin.detach().cpu().numpy() + atoms_list = [] for sys_idx in np.unique(system_indices): mask = system_indices == sys_idx @@ -68,6 +73,13 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: atoms = Atoms( symbols=symbols, positions=system_positions, cell=system_cell, pbc=pbc ) + + # Preserve charge and spin in atoms.info (as integers for FairChem compatibility) + if charge is not None: + atoms.info["charge"] = int(charge[sys_idx].item()) + if spin is not None: + atoms.info["spin"] = int(spin[sys_idx].item()) + atoms_list.append(atoms) return atoms_list diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 83d783e68..602cf7529 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -222,10 +222,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict: atoms.info["spin"] = sim_state.spin[idx].item() # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) + # r_data_keys must be passed for charge/spin to be read from atoms.info if self.task_name is None: - atomic_data = AtomicData.from_ase(atoms) + atomic_data = AtomicData.from_ase(atoms, r_data_keys=["charge", "spin"]) else: - atomic_data = AtomicData.from_ase(atoms, task_name=self.task_name) + atomic_data = AtomicData.from_ase( + atoms, task_name=self.task_name, r_data_keys=["charge", "spin"] + ) atomic_data_list.append(atomic_data) # Create batch for efficient inference diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index e69a4955f..763cb646c 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -72,24 +72,12 @@ def fire_init( forces = model_output["forces"] stress = model_output.get("stress") - # Common state arguments - common_args = { - # Copy SimState attributes - "positions": state.positions.clone(), - "masses": state.masses.clone(), - "cell": state.cell.clone(), - "atomic_numbers": state.atomic_numbers.clone(), - "system_idx": state.system_idx.clone(), - "_constraints": state.constraints, - "pbc": state.pbc, - "charge": state.charge.clone(), - "spin": state.spin.clone(), - # Optimization state + # FIRE-specific additional attributes + fire_attrs = { "forces": forces, "energy": energy, "stress": stress, "velocities": torch.full(state.positions.shape, torch.nan, **tensor_args), - # FIRE parameters "dt": torch.full((n_systems,), dt_start, **tensor_args), "alpha": torch.full((n_systems,), alpha_start, **tensor_args), "n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), @@ -97,9 +85,9 @@ def fire_init( if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - common_args["reference_cell"] = state.cell.clone() - common_args["cell_filter"] = cell_filter_funcs - cell_state = CellFireState(**common_args) + fire_attrs["reference_cell"] = state.cell.clone() + fire_attrs["cell_filter"] = cell_filter_funcs + cell_state = CellFireState.from_state(state, **fire_attrs) # Initialize cell-specific attributes init_fn(cell_state, model, **filter_kwargs) @@ -111,7 +99,7 @@ def fire_init( return cell_state # Create regular FireState without cell optimization - return FireState(**common_args) + return FireState.from_state(state, **fire_attrs) def fire_step( diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 283fe7c80..23a51a0ed 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -50,34 +50,25 @@ def gradient_descent_init( forces = model_output["forces"] stress = model_output.get("stress") - # Common state arguments - common_args = { - "positions": state.positions, + # Optimizer-specific additional attributes + optim_attrs = { "forces": forces, "energy": energy, "stress": stress, - "masses": state.masses, - "cell": state.cell, - "pbc": state.pbc, - "atomic_numbers": state.atomic_numbers, - "system_idx": state.system_idx, - "_constraints": state.constraints, - "charge": state.charge, - "spin": state.spin, } if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - common_args["reference_cell"] = state.cell.clone() - common_args["cell_filter"] = cell_filter_funcs - cell_state = CellOptimState(**common_args) + optim_attrs["reference_cell"] = state.cell.clone() + optim_attrs["cell_filter"] = cell_filter_funcs + cell_state = CellOptimState.from_state(state, **optim_attrs) # Initialize cell-specific attributes init_fn(cell_state, model, **filter_kwargs) return cell_state # Create regular OptimState without cell optimization - return OptimState(**common_args) + return OptimState.from_state(state, **optim_attrs) def gradient_descent_step( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 21d10be8e..ff445bf9e 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -762,13 +762,8 @@ class StaticState(SimState): ) model_outputs = model(sub_state) - static_state = StaticState( - positions=sub_state.positions, - masses=sub_state.masses, - cell=sub_state.cell, - pbc=sub_state.pbc, - atomic_numbers=sub_state.atomic_numbers, - system_idx=sub_state.system_idx, + static_state = StaticState.from_state( + state=sub_state, energy=model_outputs["energy"], forces=( model_outputs["forces"] diff --git a/torch_sim/state.py b/torch_sim/state.py index e97101d73..ec2c589de 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -96,14 +96,13 @@ class SimState: if TYPE_CHECKING: @property - def system_idx(self) -> torch.Tensor: - """A getter for system_idx that tells type checkers it's always defined.""" - return self.system_idx - + def system_idx(self) -> torch.Tensor: ... # noqa: D102 + @property + def pbc(self) -> torch.Tensor: ... # noqa: D102 @property - def pbc(self) -> torch.Tensor: - """A getter for pbc that tells type checkers it's always defined.""" - return self.pbc + def charge(self) -> torch.Tensor: ... # noqa: D102 + @property + def spin(self) -> torch.Tensor: ... # noqa: D102 _atom_attributes: ClassVar[set[str]] = { "positions", @@ -183,6 +182,16 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") + @classmethod + def _get_all_attributes(cls) -> set[str]: + """Get all attributes of the SimState.""" + return ( + cls._atom_attributes + | cls._system_attributes + | cls._global_attributes + | {"_constraints"} + ) + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, @@ -224,13 +233,7 @@ def volume(self) -> torch.Tensor: @property def attributes(self) -> dict[str, torch.Tensor]: """Get all public attributes of the state.""" - return { - attr: getattr(self, attr) - for attr in self._atom_attributes - | self._system_attributes - | self._global_attributes - | {"_constraints"} - } + return {attr: getattr(self, attr) for attr in self._get_all_attributes()} @property def column_vector_cell(self) -> torch.Tensor: @@ -369,9 +372,10 @@ def clone(self) -> Self: def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: """Create a new state from an existing state with additional attributes. - This method copies all attributes from the source state and adds any additional - attributes needed for the target state class. It's useful for converting between - different state types (e.g., SimState to MDState). + This method copies attributes from the source state that are valid for the + target state class, and adds any additional attributes needed. It supports + upcasting (SimState -> MDState), downcasting (MDState -> SimState), and + cross-casting (MDState -> OptimState) between state types. Args: state: Source state to copy base attributes from @@ -389,13 +393,13 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: ... momenta=torch.zeros_like(sim_state.positions), ... ) """ - # Copy all attributes from the source state attrs = {} for attr_name, attr_value in state.attributes.items(): - if isinstance(attr_value, torch.Tensor): - attrs[attr_name] = attr_value.clone() - else: - attrs[attr_name] = copy.deepcopy(attr_value) + if attr_name in cls._get_all_attributes(): + if isinstance(attr_value, torch.Tensor): + attrs[attr_name] = attr_value.clone() + else: + attrs[attr_name] = copy.deepcopy(attr_value) # Add/override with additional attributes attrs.update(additional_attrs)