Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ jobs:
- { name: mattersim, test_path: "tests/models/test_mattersim.py" }
- { name: metatomic, test_path: "tests/models/test_metatomic.py" }
- { name: nequip, test_path: "tests/models/test_nequip_framework.py" }
- { name: nequix, test_path: "tests/models/test_nequix.py" }
- { name: orb, test_path: "tests/models/test_orb.py" }
- { name: sevenn, test_path: "tests/models/test_sevennet.py" }
exclude:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ era. By rewriting the core primitives of atomistic simulation in Pytorch, it all
orders of magnitude acceleration of popular machine learning potentials.

* Automatic batching and GPU memory management allowing significant simulation speedup
* Support for MACE, Fairchem, SevenNet, ORB, MatterSim, graph-pes, and metatomic MLIP models
* Support for MACE, Fairchem, SevenNet, ORB, MatterSim, graph-pes, metatomic, and Nequix MLIP models
* Support for classical lennard jones, morse, and soft-sphere potentials
* Molecular dynamics integration schemes like NVE, NVT Langevin, and NPT Langevin
* Relaxation of atomic positions and cell with gradient descent and FIRE
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ orb = ["orb-models>=0.6.0"]
sevenn = ["sevenn[torchsim]>=0.12.1"]
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
nequip = ["nequip>=0.17.0"]
nequix = ["nequix[torch-sim]>=0.4.5"]
fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"]
docs = [
"autodoc_pydantic==2.2.0",
Expand Down Expand Up @@ -167,6 +168,10 @@ conflicts = [
{ extra = "graphpes" },
{ extra = "nequip" },
],
[
{ extra = "graphpes" },
{ extra = "nequix" },
],
[
{ extra = "graphpes" },
{ extra = "sevenn" },
Expand All @@ -179,6 +184,10 @@ conflicts = [
{ extra = "mace" },
{ extra = "nequip" },
],
[
{ extra = "mace" },
{ extra = "nequix" },
],
[
{ extra = "mace" },
{ extra = "sevenn" },
Expand Down
54 changes: 54 additions & 0 deletions tests/models/test_nequix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import traceback

import pytest

from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import (
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)
from torch_sim.testing import SIMSTATE_BULK_GENERATORS


try:
from nequix.calculator import NequixCalculator

from torch_sim.models.nequix import NequixModel
except (ImportError, ModuleNotFoundError):
pytest.skip(
f"nequix not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments]
allow_module_level=True,
)


@pytest.fixture(scope="session")
def nequix_model() -> NequixModel:
return NequixModel("nequix-mp-1", device=DEVICE, dtype=DTYPE, use_kernel=False)


@pytest.fixture(scope="session")
def nequix_calculator() -> NequixCalculator:
return NequixCalculator(
"nequix-mp-1",
device=DEVICE,
backend="torch",
use_compile=False,
use_kernel=False,
)


test_nequix_consistency = make_model_calculator_consistency_test(
test_name="nequix",
model_fixture_name="nequix_model",
calculator_fixture_name="nequix_calculator",
sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()),
force_atol=5e-5,
dtype=DTYPE,
device=DEVICE,
)

test_nequix_model_outputs = make_validate_model_outputs_test(
model_fixture_name="nequix_model",
dtype=DTYPE,
device=DEVICE,
)
43 changes: 43 additions & 0 deletions torch_sim/models/nequix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Wrapper for Nequix models in TorchSim.

This module re-exports the nequix model's torch-sim integration for convenient
importing. The actual implementation is maintained in the nequix package.

References:
- nequix Package: https://github.com/atomicarchitects/nequix

"""

import traceback
import warnings
from typing import Any, Self


try:
from nequix.torch_sim import NequixTorchSimModel

# Re-export with backward-compatible name
class NequixModel(NequixTorchSimModel):
"""Nequix model wrapper for torch-sim."""

except ImportError as exc:
_nequix_import_error = exc # capture before except block ends (exc is deleted)
warnings.warn(f"Nequix import failed: {traceback.format_exc()}", stacklevel=2)

from torch_sim.models.interface import ModelInterface

class NequixModel(ModelInterface):
"""Nequix model wrapper for torch-sim.

NOTE: This class is a placeholder when nequix is not installed.
It raises an ImportError if accessed.
"""

def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
"""Dummy init for type checking."""
raise err

@classmethod
def from_compiled_model(cls, _path: Any, *_args: Any, **_kwargs: Any) -> Self:
"""Dummy classmethod for type checking when nequix is not installed."""
raise _nequix_import_error
Loading