# Patch summary: thread per-system total charge and spin through SimState and
# into the FairChem (UMA) and MACE model wrappers.
#
# Files touched, as visible in the hunks below:
#   * pyproject.toml
#       - the "mace" extra's lower bound moves from mace-torch>=0.3.12 to >=0.3.14.
#   * torch_sim/state.py
#       - SimState gains two optional fields, `charge` and `spin`
#         (torch.Tensor | None, default None), declared between
#         `atomic_numbers` and `system_idx`.
#       - Both names are added to `_system_attributes` and to the `exceptions`
#         set in `_assert_no_tensor_attributes_can_be_none`.
#       - `__post_init__` replaces a None charge/spin with
#         torch.zeros(n_systems, device=self.device, dtype=self.dtype) and raises
#         ValueError when shape[0] of a supplied tensor differs from n_systems.
#   * torch_sim/io.py
#       - `atoms_to_state` builds `charge` and `spin` tensors from
#         `at.info.get("charge", 0.0)` / `at.info.get("spin", 0.0)` for every
#         Atoms object and forwards them to ts.SimState.
#   * torch_sim/models/fairchem.py
#       - `forward` writes `sim_state.charge[idx].item()` and
#         `sim_state.spin[idx].item()` into `atoms.info` before the ASE Atoms
#         object is converted to AtomicData.
#   * torch_sim/models/mace.py
#       - `forward` now assembles the model input in a named `data_dict` and
#         adds the keys `total_charge` and `total_spin`.
#   * tests/models/test_fairchem.py, tests/models/test_mace.py
#       - new tests parametrized over four (charge, spin) pairs that check the
#         values survive `ts.io.atoms_to_state` and that the model call returns
#         "energy" of shape (1,) and "forces" of per-atom shape.
#       - test_mace.py additionally imports `mace_omol` and loads
#         `raw_mace_omol` at module level.
#   * tests/test_state.py
#       - the expected per-system attribute set becomes {"cell", "charge", "spin"}.
#
# NOTE(review): `charge` and `spin` are inserted *before* `system_idx` in the
#   dataclass field order, so any caller passing `system_idx` positionally would
#   now bind it to `charge` -- confirm all constructors use keyword arguments.
# NOTE(review): the new validation only compares `shape[0]` with n_systems; a
#   tensor with extra trailing dimensions would pass -- confirm that is intended.
# NOTE(review): `mace_omol(model="extra_large", ...)` runs at test-module import
#   time; presumably this fetches a large checkpoint -- verify CI cost.
# NOTE(review): the test comments treat spin=0 as singlet and spin=1 as doublet;
#   confirm this matches the spin convention each backend expects.
diff --git a/pyproject.toml b/pyproject.toml index 27e3696c8..14bf4d0ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ test = [ "pytest>=8", ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] -mace = ["mace-torch>=0.3.12"] +mace = ["mace-torch>=0.3.14"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.7"] orb = ["orb-models>=0.5.2"] diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index f445a987d..bafcc8ced 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -246,3 +246,53 @@ def test_load_from_checkpoint_path() -> None: model_fixture_name="eqv2_uma_model_pbc", device=DEVICE, dtype=DTYPE ) ) + + +@pytest.mark.skipif( + get_token() is None, reason="Requires HuggingFace authentication for UMA model access" +) +@pytest.mark.parametrize( + ("charge", "spin"), + [ + (0.0, 0.0), # Neutral, no spin + (1.0, 1.0), # +1 charge, spin=1 (doublet) + (-1.0, 0.0), # -1 charge, no spin (singlet) + (0.0, 2.0), # Neutral, spin=2 (triplet) + ], +) +def test_fairchem_charge_spin(charge: float, spin: float) -> None: + """Test that FairChemModel correctly handles charge and spin from atoms.info.""" + # Create a water molecule + mol = molecule("H2O") + + # Set charge and spin in ASE atoms.info + mol.info["charge"] = charge + mol.info["spin"] = spin + + # Convert to SimState (should extract charge/spin) + state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE) + + # Verify charge/spin were extracted correctly + assert state.charge[0].item() == charge + assert state.spin[0].item() == spin + + # Create model with UMA omol task (supports charge/spin for molecules) + model = FairChemModel( + model=None, + model_name="uma-s-1", + task_name="omol", + cpu=DEVICE.type == "cpu", + ) + + # This should not raise an error + result = model(state) + + # Verify outputs exist + assert "energy" in result + assert result["energy"].shape == (1,) + assert "forces" in 
result + assert result["forces"].shape == (len(mol), 3) + + # Verify outputs are finite + assert torch.isfinite(result["energy"]).all() + assert torch.isfinite(result["forces"]).all() diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 0eabd99a1..50f89697b 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -16,7 +16,7 @@ try: from mace.calculators import MACECalculator - from mace.calculators.foundations_models import mace_mp, mace_off + from mace.calculators.foundations_models import mace_mp, mace_off, mace_omol from torch_sim.models.mace import MaceModel except (ImportError, ValueError): @@ -25,6 +25,7 @@ raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) raw_mace_off = mace_off(model=MaceUrls.mace_off_small, return_raw_model=True) +raw_mace_omol = mace_omol(model="extra_large", return_raw_model=True) DTYPE = torch.float64 @@ -137,3 +138,53 @@ def test_mace_urls_enum() -> None: for key in MaceUrls: assert key.value.startswith("https://github.com/ACEsuit/mace-") assert key.value.endswith((".model", ".model?raw=true")) + + +@pytest.mark.parametrize( + ("charge", "spin"), + [ + (0.0, 0.0), # Neutral, no spin + (1.0, 1.0), # +1 charge, spin=1 (doublet) + (-1.0, 0.0), # -1 charge, no spin (singlet) + (0.0, 2.0), # Neutral, spin=2 (triplet) + ], +) +def test_mace_charge_spin(benzene_atoms: Atoms, charge: float, spin: float) -> None: + """Test that MaceModel correctly handles charge and spin from atoms.info.""" + # Set charge and spin in ASE atoms.info + benzene_atoms.info["charge"] = charge + benzene_atoms.info["spin"] = spin + + # Convert to SimState (should extract charge/spin) + state = ts.io.atoms_to_state([benzene_atoms], DEVICE, DTYPE) + + # Verify charge/spin were extracted correctly + if charge != 0.0: + assert state.charge is not None + assert state.charge[0].item() == charge + else: + assert state.charge is None or state.charge[0].item() == 0.0 + + if spin != 0.0: + assert state.spin is not 
None + assert state.spin[0].item() == spin + else: + assert state.spin is None or state.spin[0].item() == 0.0 + + # Create model with MACE-OMOL (supports charge/spin for molecules) + model = MaceModel( + model=raw_mace_omol, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + ) + + # This should not raise an error + result = model.forward(state) + + # Verify outputs exist + assert "energy" in result + assert result["energy"].shape == (1,) + if model.compute_forces: + assert "forces" in result + assert result["forces"].shape == benzene_atoms.positions.shape diff --git a/tests/test_state.py b/tests/test_state.py index 1b2936681..426e3404b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -28,7 +28,7 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs) == {"cell"} + assert set(per_system_attrs) == {"cell", "charge", "spin"} global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) assert set(global_attrs) == {"pbc"} diff --git a/torch_sim/io.py b/torch_sim/io.py index eee9e4808..27be6b1c4 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -232,6 +232,13 @@ def atoms_to_state( if not all(np.all(np.equal(at.pbc, atoms_list[0].pbc)) for at in atoms_list[1:]): raise ValueError("All systems must have the same periodic boundary conditions") + charge = torch.tensor( + [at.info.get("charge", 0.0) for at in atoms_list], dtype=dtype, device=device + ) + spin = torch.tensor( + [at.info.get("spin", 0.0) for at in atoms_list], dtype=dtype, device=device + ) + return ts.SimState( positions=positions, masses=masses, @@ -239,6 +246,8 @@ def atoms_to_state( pbc=atoms_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, + charge=charge, + spin=spin, ) diff --git 
a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 7600f7d8b..f4d1d82c7 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -223,6 +223,9 @@ def forward(self, state: ts.SimState | StateDict) -> dict: pbc=pbc if cell is not None else False, ) + atoms.info["charge"] = sim_state.charge[idx].item() + atoms.info["spin"] = sim_state.spin[idx].item() + # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) if self.task_name is None: atomic_data = AtomicData.from_ase(atoms) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 2d5fe6c32..8198ad98c 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -330,19 +330,24 @@ def forward( # noqa: C901 unit_shifts = torch.cat(unit_shifts_list, dim=0) shifts = torch.cat(shifts_list, dim=0) + # Build data dict for MACE model + data_dict = dict( + ptr=self.ptr, + node_attrs=self.node_attrs, + batch=sim_state.system_idx, + pbc=sim_state.pbc, + cell=sim_state.row_vector_cell, + positions=sim_state.positions, + edge_index=edge_index, + unit_shifts=unit_shifts, + shifts=shifts, + total_charge=sim_state.charge, + total_spin=sim_state.spin, + ) + # Get model output out = self.model( - dict( - ptr=self.ptr, - node_attrs=self.node_attrs, - batch=sim_state.system_idx, - pbc=sim_state.pbc, - cell=sim_state.row_vector_cell, - positions=sim_state.positions, - edge_index=edge_index, - unit_shifts=unit_shifts, - shifts=shifts, - ), + data_dict, compute_force=self.compute_forces, compute_stress=self.compute_stress, ) diff --git a/torch_sim/state.py b/torch_sim/state.py index 813354fe8..0b2d6ef86 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -84,6 +84,8 @@ class SimState: cell: torch.Tensor pbc: torch.Tensor | list[bool] | bool atomic_numbers: torch.Tensor + charge: torch.Tensor | None = field(default=None) + spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor | None = field(default=None) if TYPE_CHECKING: @@ -104,10 
+106,10 @@ def pbc(self) -> torch.Tensor: "atomic_numbers", "system_idx", } - _system_attributes: ClassVar[set[str]] = {"cell"} + _system_attributes: ClassVar[set[str]] = {"cell", "charge", "spin"} _global_attributes: ClassVar[set[str]] = {"pbc"} - def __post_init__(self) -> None: + def __post_init__(self) -> None: # noqa: C901 """Initialize the SimState and validate the arguments.""" # Check that positions, masses and atomic numbers have compatible shapes shapes = [ @@ -136,6 +138,17 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(initial_system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.charge is None: + self.charge = torch.zeros( + self.n_systems, device=self.device, dtype=self.dtype + ) + elif self.charge.shape[0] != self.n_systems: + raise ValueError(f"Charge must have shape (n_systems={self.n_systems},)") + if self.spin is None: + self.spin = torch.zeros(self.n_systems, device=self.device, dtype=self.dtype) + elif self.spin.shape[0] != self.n_systems: + raise ValueError(f"Spin must have shape (n_systems={self.n_systems},)") + if self.cell.ndim != 3 and initial_system_idx is None: self.cell = self.cell.unsqueeze(0) @@ -419,7 +432,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: # exceptions exist because the type hint doesn't actually reflect the real type # (since we change their type in the post_init) - exceptions = {"system_idx"} + exceptions = {"system_idx", "charge", "spin"} type_hints = typing.get_type_hints(cls) for attr_name, attr_type_hint in type_hints.items():