Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
workflow_call:

concurrency:
group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }}
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
Expand Down Expand Up @@ -37,20 +37,11 @@ jobs:
- name: Install torch_sim
run: |
uv pip install "torch>2" --index-url https://download.pytorch.org/whl/cpu --system
uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system
uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system

- name: Run core tests
run: |
pytest --cov=torch_sim --cov-report=xml \
--ignore=tests/test_elastic.py \
--ignore=tests/models/test_fairchem.py \
--ignore=tests/models/test_graphpes.py \
--ignore=tests/models/test_mace.py \
--ignore=tests/models/test_orb.py \
--ignore=tests/models/test_sevennet.py \
--ignore=tests/models/test_mattersim.py \
--ignore=tests/models/test_metatomic.py \
--ignore=tests/test_optimizers_vs_ase.py \
pytest --cov=torch_sim --cov-report=xml

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
Expand All @@ -68,6 +59,7 @@ jobs:
- { python: "3.12", resolution: lowest-direct }
model:
- { name: fairchem, test_path: "tests/models/test_fairchem.py" }
- { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" }
- { name: graphpes, test_path: "tests/models/test_graphpes.py" }
- { name: mace, test_path: "tests/models/test_mace.py" }
- { name: mace, test_path: "tests/test_elastic.py" }
Expand All @@ -82,6 +74,14 @@ jobs:
- name: Check out repo
uses: actions/checkout@v4

- name: Check out fairchem repository
if: ${{ matrix.model.name == 'fairchem-legacy' }}
uses: actions/checkout@v4
with:
repository: FAIR-Chem/fairchem
path: fairchem-repo
ref: fairchem_core-1.10.0

- name: Set up Python
uses: actions/setup-python@v5
with:
Expand All @@ -90,29 +90,32 @@ jobs:
- name: Set up uv
uses: astral-sh/setup-uv@v6

- name: Install fairchem and dependencies
if: ${{ matrix.model.name == 'fairchem' }}
env:
HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
- name: Install legacy fairchem repository and dependencies
if: ${{ matrix.model.name == 'fairchem-legacy' }}
run: |
uv pip install "torch>=2.6" --index-url https://download.pytorch.org/whl/cpu --system
uv pip install "fairchem-core>=2.2.0" --system
uv pip install "huggingface_hub[cli]" --system
uv pip install -e .[test] --resolution=${{ matrix.version.resolution }} --system
if [ -f fairchem-repo/packages/requirements.txt ]; then
uv pip install -r fairchem-repo/packages/requirements.txt --system
fi
if [ -f fairchem-repo/packages/requirements-optional.txt ]; then
uv pip install -r fairchem-repo/packages/requirements-optional.txt --system
fi
uv pip install -e fairchem-repo/packages/fairchem-core[dev] --system
uv pip install -e ".[test]" --resolution=${{ matrix.version.resolution }} --system

- name: Install torch_sim with model dependencies
if: ${{ matrix.model.name != 'fairchem' }}
if: ${{ matrix.model.name != 'fairchem-legacy' }}
run: |
uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system
uv pip install -e ".[test,${{ matrix.model.name }}]" --resolution=${{ matrix.version.resolution }} --system

- name: Run ${{ matrix.model.test_path }} tests
env:
HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
run: |
if [ "${{ matrix.model.name }}" == "fairchem" ]; then
if [[ "${{ matrix.model.name }}" == *"fairchem"* ]]; then
uv pip install "huggingface_hub[cli]" --system
huggingface-cli login --token "$HF_TOKEN"
fi
pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }}
pytest -vv -ra -rs --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }}

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ metatomic = ["metatomic-torch>=0.1.1,<0.2", "metatrain[pet]==2025.7"]
orb = ["orb-models>=0.5.2"]
sevenn = ["sevenn>=0.11.0"]
graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.12"]
fairchem = ["fairchem-core>=2.2.0"]
docs = [
"autodoc_pydantic==2.2.0",
"furo==2024.8.6",
Expand Down Expand Up @@ -142,6 +143,14 @@ testpaths = ["tests"]
# make these dependencies mutually exclusive since they use incompatible e3nn versions
# see https://docs.astral.sh/uv/concepts/projects/config/#conflicting-dependencies for more details
conflicts = [
[
{ extra = "fairchem" },
{ extra = "graphpes" },
],
[
{ extra = "fairchem" },
{ extra = "mace" },
],
[
{ extra = "graphpes" },
{ extra = "mattersim" },
Expand Down
15 changes: 13 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -48,7 +49,12 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel:
@pytest.fixture
def ase_mace_mpa() -> "MACECalculator":
"""Provides an ASE MACECalculator instance using mace_mp."""
from mace.calculators.foundations_models import mace_mp
try:
from mace.calculators.foundations_models import mace_mp
except (ImportError, ModuleNotFoundError):
pytest.skip(
f"MACE not installed: {traceback.format_exc()}", allow_module_level=True
)

# Ensure dtype matches the one used in the torchsim fixture (float64)
return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64")
Expand All @@ -57,7 +63,12 @@ def ase_mace_mpa() -> "MACECalculator":
@pytest.fixture
def torchsim_mace_mpa() -> MaceModel:
"""Provides a MACE MP model instance for the optimizer tests."""
from mace.calculators.foundations_models import mace_mp
try:
from mace.calculators.foundations_models import mace_mp
except (ImportError, ModuleNotFoundError):
pytest.skip(
f"MACE not installed: {traceback.format_exc()}", allow_module_level=True
)

# Use float64 for potentially higher precision needed in optimization
dtype = getattr(torch, dtype_str := "float64")
Expand Down
6 changes: 5 additions & 1 deletion tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import traceback

import pytest
import torch

Expand All @@ -14,7 +16,9 @@
from torch_sim.models.fairchem import FairChemModel

except ImportError:
pytest.skip("FairChem not installed", allow_module_level=True)
pytest.skip(
f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True
)


@pytest.fixture
Expand Down
102 changes: 102 additions & 0 deletions tests/models/test_fairchem_legacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import traceback

import pytest
import torch

from tests.models.conftest import (
consistency_test_simstate_fixtures,
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)


try:
from fairchem.core import OCPCalculator
from fairchem.core.models.model_registry import model_name_to_local_file
from huggingface_hub.utils._auth import get_token

from torch_sim.models.fairchem_legacy import FairChemV1Model

except ImportError:
pytest.skip(
f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True
)


@pytest.fixture(scope="session")
def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str:
tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
model_name = "EquiformerV2-31M-S2EF-OC20-All+MD"
return model_name_to_local_file(model_name, local_cache=str(tmp_path))


@pytest.fixture
def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemV1Model:
cpu = device.type == "cpu"
return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=True)


@pytest.fixture
def eqv2_oc20_model_non_pbc(
model_path_oc20: str, device: torch.device
) -> FairChemV1Model:
cpu = device.type == "cpu"
return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=False)


if get_token():

@pytest.fixture(scope="session")
def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str:
tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
model_name = "EquiformerV2-31M-OMAT24-MP-sAlex"
return model_name_to_local_file(model_name, local_cache=str(tmp_path))

@pytest.fixture
def eqv2_omat24_model_pbc(
model_path_omat24: str, device: torch.device
) -> FairChemV1Model:
cpu = device.type == "cpu"
return FairChemV1Model(model=model_path_omat24, cpu=cpu, seed=0, pbc=True)


@pytest.fixture
def ocp_calculator(model_path_oc20: str) -> OCPCalculator:
return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0)


test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test(
test_name="fairchem_ocp",
model_fixture_name="eqv2_oc20_model_pbc",
calculator_fixture_name="ocp_calculator",
sim_state_names=consistency_test_simstate_fixtures[:-1],
energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
energy_atol=5e-4,
force_rtol=5e-4,
force_atol=5e-4,
stress_rtol=5e-4,
stress_atol=5e-4,
)

test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test(
test_name="fairchem_non_pbc_benzene",
model_fixture_name="eqv2_oc20_model_non_pbc",
calculator_fixture_name="ocp_calculator",
sim_state_names=["benzene_sim_state"],
energy_rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
energy_atol=5e-4,
force_rtol=5e-4,
force_atol=5e-4,
stress_rtol=5e-4,
stress_atol=5e-4,
)


# Skip this test due to issues with how the older models
# handled supercells (see related issue here: https://github.com/facebookresearch/fairchem/issues/428)

test_fairchem_ocp_model_outputs = pytest.mark.skipif(
os.environ.get("HF_TOKEN") is None,
reason="Issues in graph construction of older models",
)(make_validate_model_outputs_test(model_fixture_name="eqv2_omat24_model_pbc"))
6 changes: 5 additions & 1 deletion tests/models/test_graphpes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import traceback

import pytest
import torch
from ase.build import bulk, molecule
Expand All @@ -16,7 +18,9 @@
from graph_pes.interfaces import mace_mp
from graph_pes.models import LennardJones, SchNet, TensorNet, ZEmbeddingNequIP
except ImportError:
pytest.skip("graph-pes not installed", allow_module_level=True)
pytest.skip(
f"graph-pes not installed: {traceback.format_exc()}", allow_module_level=True
)


@pytest.fixture
Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_mace.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import traceback

import pytest
import torch
from ase.atoms import Atoms
Expand All @@ -17,7 +19,7 @@

from torch_sim.models.mace import MaceModel
except (ImportError, ValueError):
pytest.skip("MACE not installed", allow_module_level=True)
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)


mace_model = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True)
Expand Down
6 changes: 5 additions & 1 deletion tests/models/test_mattersim.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# codespell-ignore: convertor

import traceback

import ase.spacegroup
import ase.units
import pytest
Expand All @@ -18,7 +20,9 @@
from torch_sim.models.mattersim import MatterSimModel

except ImportError:
pytest.skip("mattersim not installed", allow_module_level=True)
pytest.skip(
f"mattersim not installed: {traceback.format_exc()}", allow_module_level=True
)


@pytest.fixture
Expand Down
6 changes: 5 additions & 1 deletion tests/models/test_metatomic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import traceback

import pytest
import torch

Expand All @@ -14,7 +16,9 @@

from torch_sim.models.metatomic import MetatomicModel
except ImportError:
pytest.skip("metatomic not installed", allow_module_level=True)
pytest.skip(
f"metatomic not installed: {traceback.format_exc()}", allow_module_level=True
)


@pytest.fixture
Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_orb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import traceback

import pytest
import torch

Expand All @@ -14,7 +16,7 @@

from torch_sim.models.orb import OrbModel
except ImportError:
pytest.skip("ORB not installed", allow_module_level=True)
pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True)


@pytest.fixture
Expand Down
6 changes: 5 additions & 1 deletion tests/models/test_sevennet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import traceback

import pytest
import torch

Expand All @@ -16,7 +18,9 @@
from torch_sim.models.sevennet import SevenNetModel

except ImportError:
pytest.skip("sevenn not installed", allow_module_level=True)
pytest.skip(
f"sevenn not installed: {traceback.format_exc()}", allow_module_level=True
)


@pytest.fixture
Expand Down
4 changes: 3 additions & 1 deletion tests/test_elastic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import traceback

import pytest
import torch

Expand All @@ -20,7 +22,7 @@

from torch_sim.models.mace import MaceModel
except ImportError:
pytest.skip("MACE not installed", allow_module_level=True)
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)


def test_get_strain_zero_deformation(cu_sim_state: ts.SimState) -> None:
Expand Down
Loading