diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f47f6238e..2ed7149ac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ on: workflow_call: concurrency: - group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }} + group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: @@ -37,20 +37,11 @@ jobs: - name: Install torch_sim run: | uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system - uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system + uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system - name: Run core tests run: | - pytest --cov=torch_sim --cov-report=xml \ - --ignore=tests/test_elastic.py \ - --ignore=tests/models/test_fairchem.py \ - --ignore=tests/models/test_graphpes.py \ - --ignore=tests/models/test_mace.py \ - --ignore=tests/models/test_orb.py \ - --ignore=tests/models/test_sevennet.py \ - --ignore=tests/models/test_mattersim.py \ - --ignore=tests/models/test_metatomic.py \ - --ignore=tests/test_optimizers_vs_ase.py \ + pytest --cov=torch_sim --cov-report=xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 @@ -68,6 +59,7 @@ jobs: - { python: "3.12", resolution: lowest-direct } model: - { name: fairchem, test_path: "tests/models/test_fairchem.py" } + - { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } - { name: graphpes, test_path: "tests/models/test_graphpes.py" } - { name: mace, test_path: "tests/models/test_mace.py" } - { name: mace, test_path: "tests/test_elastic.py" } @@ -82,6 +74,14 @@ jobs: - name: Check out repo uses: actions/checkout@v4 + - name: Check out fairchem repository + if: ${{ matrix.model.name == 'fairchem-legacy' }} + uses: actions/checkout@v4 + with: + repository: FAIR-Chem/fairchem + path: fairchem-repo + ref: fairchem_core-1.10.0 + - name: Set up Python uses: actions/setup-python@v5 with: @@ -90,29 +90,32 @@ jobs: - name: Set up uv uses: astral-sh/setup-uv@v6 - - name: Install fairchem and dependencies - if: ${{ matrix.model.name == 'fairchem' }} - env: - HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} + - name: Install legacy fairchem repository and dependencies + if: ${{ matrix.model.name == 'fairchem-legacy' }} run: | - uv pip install "torch>=2.6" --index-url https://download.pytorch.org/whl/cpu --system - uv pip install "fairchem-core>=2.2.0" --system - uv pip install "huggingface_hub[cli]" --system - uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system + if [ -f fairchem-repo/packages/requirements.txt ]; then + uv pip install -r fairchem-repo/packages/requirements.txt --system + fi + if [ -f fairchem-repo/packages/requirements-optional.txt ]; then + uv pip install -r fairchem-repo/packages/requirements-optional.txt --system + fi + uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system + uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system - name: Install torch_sim with model dependencies - if: ${{ matrix.model.name != 'fairchem' }} + if: ${{ matrix.model.name != 'fairchem-legacy' }} run: | - uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system + uv pip install -e ".[test,${{ matrix.model.name }}]" --resolution=${{ matrix.version.resolution }} --system - name: Run ${{ matrix.model.test_path }} tests env: HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} run: | - if [ "${{ matrix.model.name }}" == "fairchem" ]; then + if [[ "${{ matrix.model.name }}" == *"fairchem"* ]]; then + uv pip install "huggingface_hub[cli]" --system huggingface-cli login --token "$HF_TOKEN" fi - pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }} + pytest -vv -ra -rs --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }} - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 diff --git a/pyproject.toml b/pyproject.toml index 8667285de..eed6a52c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ metatomic = ["metatomic-torch>=0.1.1,<0.2", "metatrain[pet]==2025.7"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.12"] +fairchem = ["fairchem-core>=2.2.0"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", @@ -142,6 +143,14 @@ testpaths = ["tests"] # make these dependencies mutually exclusive since they use incompatible e3nn versions # see https://docs.astral.sh/uv/concepts/projects/config/#conflicting-dependencies for more details conflicts = [ + [ + { extra = "fairchem" }, + { extra = "graphpes" }, + ], + [ + { extra = "fairchem" }, + { extra = "mace" }, + ], [ { extra = "graphpes" }, { extra = "mattersim" }, diff --git a/tests/conftest.py b/tests/conftest.py index 9a19696e7..1ada60dea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import traceback from typing import TYPE_CHECKING, Any import numpy as np @@ -48,7 +49,12 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: @pytest.fixture def ase_mace_mpa() -> "MACECalculator": """Provides an ASE MACECalculator instance using mace_mp.""" - from mace.calculators.foundations_models import mace_mp + try: + from mace.calculators.foundations_models import mace_mp + except (ImportError, ModuleNotFoundError): + pytest.skip( + f"MACE not installed: {traceback.format_exc()}", allow_module_level=True + ) # Ensure dtype matches the one used in the torchsim fixture (float64) return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") @@ -57,7 +63,12 @@ def ase_mace_mpa() -> "MACECalculator": @pytest.fixture def torchsim_mace_mpa() -> MaceModel: """Provides a MACE MP model instance for the optimizer tests.""" - from mace.calculators.foundations_models import mace_mp + try: + from mace.calculators.foundations_models import mace_mp + except (ImportError, ModuleNotFoundError): + pytest.skip( + f"MACE not installed: {traceback.format_exc()}", allow_module_level=True + ) # Use float64 for potentially higher precision needed in optimization dtype = getattr(torch, dtype_str := "float64") diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index f73b6a975..3c6034c25 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,3 +1,5 @@ +import traceback + import pytest import torch @@ -14,7 +16,9 @@ from torch_sim.models.fairchem import FairChemModel except ImportError: - pytest.skip("FairChem not installed", allow_module_level=True) + pytest.skip( + f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True + ) @pytest.fixture diff --git a/tests/models/test_fairchem_legacy.py b/tests/models/test_fairchem_legacy.py new file mode 100644 index 000000000..3073c424e --- /dev/null +++ b/tests/models/test_fairchem_legacy.py @@ -0,0 +1,102 @@ +import os +import traceback + +import pytest +import torch + +from tests.models.conftest import ( + consistency_test_simstate_fixtures, + make_model_calculator_consistency_test, + make_validate_model_outputs_test, +) + + +try: + from fairchem.core import OCPCalculator + from fairchem.core.models.model_registry import model_name_to_local_file + from huggingface_hub.utils._auth import get_token + + from torch_sim.models.fairchem_legacy import FairChemV1Model + +except ImportError: + pytest.skip( + f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True + ) + + +@pytest.fixture(scope="session") +def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str: + tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") + model_name = "EquiformerV2-31M-S2EF-OC20-All+MD" + return model_name_to_local_file(model_name, local_cache=str(tmp_path)) + + +@pytest.fixture +def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemV1Model: + cpu = device.type == "cpu" + return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=True) + + +@pytest.fixture +def eqv2_oc20_model_non_pbc( + model_path_oc20: str, device: torch.device +) -> FairChemV1Model: + cpu = device.type == "cpu" + return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=False) + + +if get_token(): + + @pytest.fixture(scope="session") + def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str: + tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") + model_name = "EquiformerV2-31M-OMAT24-MP-sAlex" + return model_name_to_local_file(model_name, local_cache=str(tmp_path)) + + @pytest.fixture + def eqv2_omat24_model_pbc( + model_path_omat24: str, device: torch.device + ) -> FairChemV1Model: + cpu = device.type == "cpu" + return FairChemV1Model(model=model_path_omat24, cpu=cpu, seed=0, pbc=True) + + +@pytest.fixture +def ocp_calculator(model_path_oc20: str) -> OCPCalculator: + return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0) + + +test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test( + test_name="fairchem_ocp", + model_fixture_name="eqv2_oc20_model_pbc", + calculator_fixture_name="ocp_calculator", + sim_state_names=consistency_test_simstate_fixtures[:-1], + energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models + energy_atol=5e-4, + force_rtol=5e-4, + force_atol=5e-4, + stress_rtol=5e-4, + stress_atol=5e-4, +) + +test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( + test_name="fairchem_non_pbc_benzene", + model_fixture_name="eqv2_oc20_model_non_pbc", + calculator_fixture_name="ocp_calculator", + sim_state_names=["benzene_sim_state"], + energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models + energy_atol=5e-4, + force_rtol=5e-4, + force_atol=5e-4, + stress_rtol=5e-4, + stress_atol=5e-4, +) + + +# Skip this test due to issues with how the older models +# handled supercells (see related issue here: https://github.com/facebookresearch/fairchem/issues/428) + +test_fairchem_ocp_model_outputs = pytest.mark.skipif( + os.environ.get("HF_TOKEN") is None, + reason="Issues in graph construction of older models", +)(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc")) diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py index 1470d3bec..46799c764 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes.py @@ -1,3 +1,5 @@ +import traceback + import pytest import torch from ase.build import bulk, molecule @@ -16,7 +18,9 @@ from graph_pes.interfaces import mace_mp from graph_pes.models import LennardJones, SchNet, TensorNet, ZEmbeddingNequIP except ImportError: - pytest.skip("graph-pes not installed", allow_module_level=True) + pytest.skip( + f"graph-pes not installed: {traceback.format_exc()}", allow_module_level=True + ) @pytest.fixture diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 427ef0647..3820e3a05 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -1,3 +1,5 @@ +import traceback + import pytest import torch from ase.atoms import Atoms @@ -17,7 +19,7 @@ from torch_sim.models.mace import MaceModel except (ImportError, ValueError): - pytest.skip("MACE not installed", allow_module_level=True) + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) mace_model = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index a137ed788..e473de4f2 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -1,5 +1,7 @@ # codespell-ignore: convertor +import traceback + import ase.spacegroup import ase.units import pytest @@ -18,7 +20,9 @@ from torch_sim.models.mattersim import MatterSimModel except ImportError: - pytest.skip("mattersim not installed", allow_module_level=True) + pytest.skip( + f"mattersim not installed: {traceback.format_exc()}", allow_module_level=True + ) @pytest.fixture diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index f467e4e75..b76ce7caa 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -1,3 +1,5 @@ +import traceback + import pytest import torch @@ -14,7 +16,9 @@ from torch_sim.models.metatomic import MetatomicModel except ImportError: - pytest.skip("metatomic not installed", allow_module_level=True) + pytest.skip( + f"metatomic not installed: {traceback.format_exc()}", allow_module_level=True + ) @pytest.fixture diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 5c75d4bdc..4c0054db2 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -1,3 +1,5 @@ +import traceback + import pytest import torch @@ -14,7 +16,7 @@ from torch_sim.models.orb import OrbModel except ImportError: - pytest.skip("ORB not installed", allow_module_level=True) + pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) @pytest.fixture diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 25bd310a9..0faa91ec7 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -1,3 +1,5 @@ +import traceback + import pytest import torch @@ -16,7 +18,9 @@ from torch_sim.models.sevennet import SevenNetModel except ImportError: - pytest.skip("sevenn not installed", allow_module_level=True) + pytest.skip( + f"sevenn not installed: {traceback.format_exc()}", allow_module_level=True + ) @pytest.fixture diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 91a063da0..46d75031a 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -1,3 +1,5 @@ +import traceback + import pytest import torch @@ -20,7 +22,7 @@ from torch_sim.models.mace import MaceModel except ImportError: - pytest.skip("MACE not installed", allow_module_level=True) + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) def test_get_strain_zero_deformation(cu_sim_state: ts.SimState) -> None: diff --git a/torch_sim/models/fairchem_legacy.py b/torch_sim/models/fairchem_legacy.py new file mode 100644 index 000000000..d50a9f09f --- /dev/null +++ b/torch_sim/models/fairchem_legacy.py @@ -0,0 +1,410 @@ +"""Wrapper for Legacy FairChem ecosystem models in TorchSim. + +This module provides a TorchSim wrapper of the FairChem models for computing +energies, forces, and stresses of atomistic systems. It serves as a wrapper around +the FairChem library, integrating it with the torch_sim framework to enable seamless +simulation of atomistic systems with machine learning potentials. + +The FairChemV1Model class adapts FairChem models to the ModelInterface protocol, +allowing them to be used within the broader torch_sim simulation framework. + +Notes: + This implementation requires FairChem < 2.0.0 to be installed and accessible. + It supports various model configurations through configuration files or + pretrained model checkpoints. +""" + +# ruff: noqa: T201 + +from __future__ import annotations + +import copy +import traceback +import typing +import warnings +from types import MappingProxyType +from typing import Any + +import torch + +import torch_sim as ts +from torch_sim.models.interface import ModelInterface + + +def _validate_fairchem_version() -> None: + """Check for a compatible legacy FairChem version.""" + from importlib.metadata import version + + from packaging.version import parse + + fairchem_version = parse(version("fairchem-core")) + if fairchem_version >= parse("2.0.0"): + raise ImportError("FairChem v1.10.0 or lower is required") + + +try: + _validate_fairchem_version() + from fairchem.core.common.registry import registry + from fairchem.core.common.utils import ( + load_config, + setup_imports, + setup_logging, + update_config, + ) + from fairchem.core.models.model_registry import model_name_to_local_file + from torch_geometric.data import Batch, Data + +except ImportError as exc: + warnings.warn(f"FairChem import failed: {traceback.format_exc()}", stacklevel=2) + + class FairChemV1Model(ModelInterface): + """FairChem model wrapper for torch_sim. + + This class is a placeholder for the FairChemV1Model class. + It raises an ImportError if FairChem is not installed. + """ + + def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: + """Dummy init for type checking.""" + raise err + + +if typing.TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + from torch_sim.typing import StateDict + +_DTYPE_DICT = { + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", +} + + +class FairChemV1Model(ModelInterface): + """Computes atomistic energies, forces and stresses using a FairChem model. + + This class wraps a FairChem model to compute energies, forces, and stresses for + atomistic systems. It handles model initialization, checkpoint loading, and + provides a forward pass that accepts a SimState object and returns model + predictions. + + The model can be initialized either with a configuration file or a pretrained + checkpoint. It supports various model architectures and configurations supported by + FairChem. + + Attributes: + neighbor_list_fn (Callable | None): Function to compute neighbor lists + config (dict): Complete model configuration dictionary + trainer: FairChem trainer object that contains the model + data_object (Batch): Data object containing system information + implemented_properties (list): Model outputs the model can compute + pbc (bool): Whether periodic boundary conditions are used + _dtype (torch.dtype): Data type used for computation + _compute_stress (bool): Whether to compute stress tensor + _compute_forces (bool): Whether to compute forces + _device (torch.device): Device where computation is performed + _reshaped_props (dict): Properties that need reshaping after computation + + Examples: + >>> model = FairChemV1Model(model="path/to/checkpoint.pt", compute_stress=True) + >>> results = model(state) + """ + + _reshaped_props = MappingProxyType( + {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} + ) + + def __init__( # noqa: C901, PLR0915 + self, + model: str | Path | None, + neighbor_list_fn: Callable | None = None, + *, # force remaining arguments to be keyword-only + config_yml: str | None = None, + model_name: str | None = None, + local_cache: str | None = None, + trainer: str | None = None, + cpu: bool = False, + seed: int | None = None, + dtype: torch.dtype | None = None, + compute_stress: bool = False, + pbc: bool = True, + disable_amp: bool = True, + ) -> None: + """Initialize the FairChemV1Model with specified configuration. + + Loads a FairChem model from either a checkpoint path or a configuration file. + Sets up the model parameters, trainer, and configuration for subsequent use + in energy and force calculations. + + Args: + model (str | Path | None): Path to model checkpoint file + neighbor_list_fn (Callable | None): Function to compute neighbor lists + (not currently supported) + config_yml (str | None): Path to configuration YAML file + model_name (str | None): Name of pretrained model to load + local_cache (str | None): Path to local model cache directory + trainer (str | None): Name of trainer class to use + cpu (bool): Whether to use CPU instead of GPU for computation + seed (int | None): Random seed for reproducibility + dtype (torch.dtype | None): Data type to use for computation + compute_stress (bool): Whether to compute stress tensor + pbc (bool): Whether to use periodic boundary conditions + disable_amp (bool): Whether to disable AMP + Raises: + RuntimeError: If both model_name and model are specified + NotImplementedError: If local_cache is not set when model_name is used + NotImplementedError: If custom neighbor list function is provided + ValueError: If stress computation is requested but not supported by model + + Notes: + Either config_yml or model must be provided. The model loads configuration + from the checkpoint if config_yml is not specified. + """ + setup_imports() + setup_logging() + super().__init__() + + self._dtype = dtype or torch.float32 + self._compute_stress = compute_stress + self._compute_forces = True + self._memory_scales_with = "n_atoms" + self.pbc = pbc + + if model_name is not None: + if model is not None: + raise RuntimeError( + "model_name and checkpoint_path were both specified, " + "please use only one at a time" + ) + if local_cache is None: + raise NotImplementedError( + "Local cache must be set when specifying a model name" + ) + model = model_name_to_local_file( + model_name=model_name, local_cache=local_cache + ) + + # Either the config path or the checkpoint path needs to be provided + if not config_yml and model is None: + raise ValueError("Either config_yml or model must be provided") + + checkpoint = None + if config_yml is not None: + if isinstance(config_yml, str): + config, duplicates_warning, duplicates_error = load_config(config_yml) + if len(duplicates_warning) > 0: + print( + "Overwritten config parameters from included configs " + f"(non-included parameters take precedence): {duplicates_warning}" + ) + if len(duplicates_error) > 0: + raise ValueError( + "Conflicting (duplicate) parameters in simultaneously " + f"included configs: {duplicates_error}" + ) + else: + config = config_yml + + # Only keeps the train data that might have normalizer values + if isinstance(config["dataset"], list): + config["dataset"] = config["dataset"][0] + elif isinstance(config["dataset"], dict): + config["dataset"] = config["dataset"].get("train", None) + else: + # Loads the config from the checkpoint directly (always on CPU). + checkpoint = torch.load(model, map_location=torch.device("cpu")) + config = checkpoint["config"] + + if trainer is not None: + config["trainer"] = trainer + else: + config["trainer"] = config.get("trainer", "ocp") + + if "model_attributes" in config: + config["model_attributes"]["name"] = config.pop("model") + config["model"] = config["model_attributes"] + + self.neighbor_list_fn = neighbor_list_fn + + if neighbor_list_fn is None: + # Calculate the edge indices on the fly + config["model"]["otf_graph"] = True + else: + raise NotImplementedError( + "Custom neighbor list is not supported for FairChemV1Model." + ) + + if "backbone" in config["model"]: + config["model"]["backbone"]["use_pbc"] = pbc + config["model"]["backbone"]["use_pbc_single"] = False + if dtype is not None: + try: + config["model"]["backbone"].update({"dtype": _DTYPE_DICT[dtype]}) + for key in config["model"]["heads"]: + config["model"]["heads"][key].update( + {"dtype": _DTYPE_DICT[dtype]} + ) + except KeyError: + print( + "WARNING: dtype not found in backbone, using default model dtype" + ) + else: + config["model"]["use_pbc"] = pbc + config["model"]["use_pbc_single"] = False + if dtype is not None: + try: + config["model"].update({"dtype": _DTYPE_DICT[dtype]}) + except KeyError: + print( + "WARNING: dtype not found in backbone, using default model dtype" + ) + + ### backwards compatibility with OCP v<2.0 + config = update_config(config) + + self.config = copy.deepcopy(config) + self.config["checkpoint"] = str(model) + del config["dataset"]["src"] + + self.trainer = registry.get_trainer_class(config["trainer"])( + task=config.get("task", {}), + model=config["model"], + dataset=[config["dataset"]], + outputs=config["outputs"], + loss_functions=config["loss_functions"], + evaluation_metrics=config["evaluation_metrics"], + optimizer=config["optim"], + identifier="", + slurm=config.get("slurm", {}), + local_rank=config.get("local_rank", 0), + is_debug=config.get("is_debug", True), + cpu=cpu, + amp=False if dtype is not None else config.get("amp", False), + inference_only=True, + ) + + if dtype is not None: + # Convert model parameters to specified dtype + self.trainer.model = self.trainer.model.to(dtype=self.dtype) + + if model is not None: + self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint) + + seed = seed if seed is not None else self.trainer.config["cmd"]["seed"] + if seed is None: + print( + "No seed has been set in model checkpoint or OCPCalculator! Results may " + "not be reproducible on re-run" + ) + else: + self.trainer.set_seed(seed) + + if disable_amp: + self.trainer.scaler = None + + self.implemented_properties = list(self.config["outputs"]) + + self._device = self.trainer.device + + stress_output = "stress" in self.implemented_properties + if not stress_output and compute_stress: + raise NotImplementedError("Stress output not implemented for this model") + + def load_checkpoint( + self, checkpoint_path: str, checkpoint: dict | None = None + ) -> None: + """Load an existing trained model checkpoint. + + Loads model parameters from a checkpoint file or dictionary, + setting the model to inference mode. + + Args: + checkpoint_path (str): Path to the trained model checkpoint file + checkpoint (dict | None): A pretrained checkpoint dictionary. If provided, + this dictionary is used instead of loading from checkpoint_path. + + Notes: + If loading fails, a message is printed but no exception is raised. + """ + try: + self.trainer.load_checkpoint(checkpoint_path, checkpoint, inference_only=True) + except NotImplementedError: + print("Unable to load checkpoint!") + + def forward(self, state: ts.SimState | StateDict) -> dict: + """Perform forward pass to compute energies, forces, and other properties. + + Takes a simulation state and computes the properties implemented by the model, + such as energy, forces, and stresses. + + Args: + state (SimState | StateDict): State object containing positions, cells, + atomic numbers, and other system information. If a dictionary is provided, + it will be converted to a SimState. + + Returns: + dict: Dictionary of model predictions, which may include: + - energy (torch.Tensor): Energy with shape [batch_size] + - forces (torch.Tensor): Forces with shape [n_atoms, 3] + - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], + if compute_stress is True + + Notes: + The state is automatically transferred to the model's device if needed. + All output tensors are detached from the computation graph. + """ + if isinstance(state, dict): + state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + + if state.device != self._device: + state = state.to(self._device) + + if state.system_idx is None: + state.system_idx = torch.zeros(state.positions.shape[0], dtype=torch.int) + + if self.pbc != state.pbc: + raise ValueError( + "PBC mismatch between model and state. " + "For FairChemV1Model PBC needs to be defined in the model class." + ) + + natoms = torch.bincount(state.system_idx) + fixed = torch.zeros((state.system_idx.size(0), natoms.sum()), dtype=torch.int) + data_list = [] + for i, (n, c) in enumerate( + zip(natoms, torch.cumsum(natoms, dim=0), strict=False) + ): + data_list.append( + Data( + pos=state.positions[c - n : c].clone(), + cell=state.row_vector_cell[i, None].clone(), + atomic_numbers=state.atomic_numbers[c - n : c].clone(), + fixed=fixed[c - n : c].clone(), + natoms=n, + pbc=torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool), + ) + ) + self.data_object = Batch.from_data_list(data_list) + + if self.dtype is not None: + self.data_object.pos = self.data_object.pos.to(self.dtype) + self.data_object.cell = self.data_object.cell.to(self.dtype) + + predictions = self.trainer.predict( + self.data_object, per_image=False, disable_tqdm=True + ) + + results = {} + + for key in predictions: + _pred = predictions[key] + if key in self._reshaped_props: + _pred = _pred.reshape(self._reshaped_props.get(key)).squeeze() + results[key] = _pred.detach() + + results["energy"] = results["energy"].squeeze(dim=1) + if results.get("stress") is not None and len(results["stress"].shape) == 2: + results["stress"] = results["stress"].unsqueeze(dim=0) + return results diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 5ca2a629c..31d31d2d9 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -35,7 +35,7 @@ try: from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.tools import atomic_numbers_to_indices, utils -except ImportError as exc: +except (ImportError, ModuleNotFoundError) as exc: warnings.warn(f"MACE import failed: {traceback.format_exc()}", stacklevel=2) class MaceModel(ModelInterface):