diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7882bd7cb..d138c4d6e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -69,6 +69,7 @@ jobs: - { name: mattersim, test_path: "tests/models/test_mattersim.py" } - { name: metatomic, test_path: "tests/models/test_metatomic.py" } - { name: nequip, test_path: "tests/models/test_nequip_framework.py" } + - { name: nequix, test_path: "tests/models/test_nequix.py" } - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } exclude: diff --git a/README.md b/README.md index 9ed2add13..ea4abcb7e 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ era. By rewriting the core primitives of atomistic simulation in Pytorch, it all orders of magnitude acceleration of popular machine learning potentials. * Automatic batching and GPU memory management allowing significant simulation speedup -* Support for MACE, Fairchem, SevenNet, ORB, MatterSim, graph-pes, and metatomic MLIP models +* Support for MACE, Fairchem, SevenNet, ORB, MatterSim, graph-pes, metatomic, and Nequix MLIP models * Support for classical lennard jones, morse, and soft-sphere potentials * Molecular dynamics integration schemes like NVE, NVT Langevin, and NPT Langevin * Relaxation of atomic positions and cell with gradient descent and FIRE diff --git a/pyproject.toml b/pyproject.toml index 4322706b1..b78516966 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ orb = ["orb-models>=0.6.0"] sevenn = ["sevenn[torchsim]>=0.12.1"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] nequip = ["nequip>=0.17.0"] +nequix = ["nequix[torch-sim]>=0.4.5"] fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] docs = [ "autodoc_pydantic==2.2.0", @@ -167,6 +168,10 @@ conflicts = [ { extra = "graphpes" }, { extra = "nequip" }, ], + [ + { extra = "graphpes" }, + { extra = "nequix" }, + ], [ { extra = "graphpes" }, { extra = "sevenn" }, @@ -179,6 +184,10 @@ conflicts = [ { extra = "mace" }, { extra = "nequip" }, ], + [ + { extra = "mace" }, + { extra = "nequix" }, + ], [ { extra = "mace" }, { extra = "sevenn" }, diff --git a/tests/models/test_nequix.py b/tests/models/test_nequix.py new file mode 100644 index 000000000..49b004e00 --- /dev/null +++ b/tests/models/test_nequix.py @@ -0,0 +1,54 @@ +import traceback + +import pytest + +from tests.conftest import DEVICE, DTYPE +from tests.models.conftest import ( + make_model_calculator_consistency_test, + make_validate_model_outputs_test, +) +from torch_sim.testing import SIMSTATE_BULK_GENERATORS + + +try: + from nequix.calculator import NequixCalculator + + from torch_sim.models.nequix import NequixModel +except (ImportError, ModuleNotFoundError): + pytest.skip( + f"nequix not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + allow_module_level=True, + ) + + +@pytest.fixture(scope="session") +def nequix_model() -> NequixModel: + return NequixModel("nequix-mp-1", device=DEVICE, dtype=DTYPE, use_kernel=False) + + +@pytest.fixture(scope="session") +def nequix_calculator() -> NequixCalculator: + return NequixCalculator( + "nequix-mp-1", + device=DEVICE, + backend="torch", + use_compile=False, + use_kernel=False, + ) + + +test_nequix_consistency = make_model_calculator_consistency_test( + test_name="nequix", + model_fixture_name="nequix_model", + calculator_fixture_name="nequix_calculator", + sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), + force_atol=5e-5, + dtype=DTYPE, + device=DEVICE, +) + +test_nequix_model_outputs = make_validate_model_outputs_test( + model_fixture_name="nequix_model", + dtype=DTYPE, + device=DEVICE, +) diff --git a/torch_sim/models/nequix.py b/torch_sim/models/nequix.py new file mode 100644 index 000000000..38f5cbdeb --- /dev/null +++ b/torch_sim/models/nequix.py @@ -0,0 +1,43 @@ +"""Wrapper for Nequix models in TorchSim. + +This module re-exports the nequix model's torch-sim integration for convenient +importing. The actual implementation is maintained in the nequix package. + +References: + - nequix Package: https://github.com/atomicarchitects/nequix + +""" + +import traceback +import warnings +from typing import Any, Self + + +try: + from nequix.torch_sim import NequixTorchSimModel + + # Re-export with backward-compatible name + class NequixModel(NequixTorchSimModel): + """Nequix model wrapper for torch-sim.""" + +except ImportError as exc: + _nequix_import_error = exc # capture before except block ends (exc is deleted) + warnings.warn(f"Nequix import failed: {traceback.format_exc()}", stacklevel=2) + + from torch_sim.models.interface import ModelInterface + + class NequixModel(ModelInterface): + """Nequix model wrapper for torch-sim. + + NOTE: This class is a placeholder when nequix is not installed. + It raises an ImportError if accessed. + """ + + def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: + """Dummy init for type checking.""" + raise err + + @classmethod + def from_compiled_model(cls, _path: Any, *_args: Any, **_kwargs: Any) -> Self: + """Dummy classmethod for type checking when nequix is not installed.""" + raise _nequix_import_error