Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 203 additions & 0 deletions tests/models/test_sum_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Tests for the SumModel composite model."""

import pytest
import torch

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from torch_sim.models.interface import SumModel, validate_model_outputs
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.morse import MorseModel


@pytest.fixture
def lj_model_a() -> LennardJonesModel:
    """Argon-like Lennard-Jones model used as the primary child model."""
    pair_params = {"sigma": 3.405, "epsilon": 0.0104, "cutoff": 2.5 * 3.405}
    return LennardJonesModel(
        **pair_params,
        device=DEVICE,
        dtype=DTYPE,
        compute_forces=True,
        compute_stress=True,
    )


@pytest.fixture
def lj_model_b() -> LennardJonesModel:
    """Second Lennard-Jones model with distinct parameters, for 3-model sums."""
    pair_params = {"sigma": 2.0, "epsilon": 0.005, "cutoff": 5.0}
    return LennardJonesModel(
        **pair_params,
        device=DEVICE,
        dtype=DTYPE,
        compute_forces=True,
        compute_stress=True,
    )


@pytest.fixture
def morse_model() -> MorseModel:
    """Morse potential model paired with the LJ model in the composite."""
    morse_params = {"sigma": 2.55, "epsilon": 0.436, "alpha": 1.359, "cutoff": 6.0}
    return MorseModel(
        **morse_params,
        device=DEVICE,
        dtype=DTYPE,
        compute_forces=True,
        compute_stress=True,
    )


@pytest.fixture
def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumModel:
    """Composite model summing the LJ and Morse fixtures."""
    composite = SumModel(lj_model_a, morse_model)
    return composite


def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None:
    """A single child model is rejected at construction time."""
    with pytest.raises(ValueError, match="at least two"):
        _ = SumModel(lj_model_a)


def test_sum_model_device_mismatch() -> None:
    """Child models on different devices are rejected."""
    kwargs = {"sigma": 1.0, "epsilon": 1.0, "cutoff": 2.5, "device": torch.device("cpu")}
    m1 = LennardJonesModel(**kwargs)
    m2 = LennardJonesModel(**kwargs)
    # Fake a different device on m2 without moving any tensors.
    object.__setattr__(m2, "_device", torch.device("meta"))
    with pytest.raises(ValueError, match="Device mismatch"):
        SumModel(m1, m2)


def test_sum_model_dtype_mismatch() -> None:
    """Child models with different dtypes are rejected."""

    def make(dtype: torch.dtype) -> LennardJonesModel:
        return LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, dtype=dtype)

    with pytest.raises(ValueError, match="Dtype mismatch"):
        SumModel(make(torch.float64), make(torch.float32))


def test_sum_model_properties(sum_model: SumModel) -> None:
    """Device, dtype, and compute flags are inherited from the children."""
    assert sum_model.device == DEVICE
    assert sum_model.dtype == DTYPE
    assert sum_model.compute_forces is True
    assert sum_model.compute_stress is True


def test_sum_model_energy_summation(
    lj_model_a: LennardJonesModel,
    morse_model: MorseModel,
    sum_model: SumModel,
    si_sim_state: ts.SimState,
) -> None:
    """Energy of the composite equals the sum of the child energies."""
    expected = lj_model_a(si_sim_state)["energy"] + morse_model(si_sim_state)["energy"]
    torch.testing.assert_close(sum_model(si_sim_state)["energy"], expected)


def test_sum_model_forces_summation(
    lj_model_a: LennardJonesModel,
    morse_model: MorseModel,
    sum_model: SumModel,
    si_sim_state: ts.SimState,
) -> None:
    """Forces of the composite equal the sum of the child forces."""
    expected = lj_model_a(si_sim_state)["forces"] + morse_model(si_sim_state)["forces"]
    torch.testing.assert_close(sum_model(si_sim_state)["forces"], expected)


def test_sum_model_stress_summation(
    lj_model_a: LennardJonesModel,
    morse_model: MorseModel,
    sum_model: SumModel,
    si_sim_state: ts.SimState,
) -> None:
    """Stress of the composite equals the sum of the child stresses."""
    expected = lj_model_a(si_sim_state)["stress"] + morse_model(si_sim_state)["stress"]
    torch.testing.assert_close(sum_model(si_sim_state)["stress"], expected)


def test_sum_model_batched(
    lj_model_a: LennardJonesModel,
    morse_model: MorseModel,
    sum_model: SumModel,
    si_double_sim_state: ts.SimState,
) -> None:
    """Key-wise summation also holds for a batched (two-system) state."""
    lj_out = lj_model_a(si_double_sim_state)
    morse_out = morse_model(si_double_sim_state)
    sum_out = sum_model(si_double_sim_state)
    for key in ("energy", "forces", "stress"):
        torch.testing.assert_close(sum_out[key], lj_out[key] + morse_out[key])


def test_sum_model_three_models(
    lj_model_a: LennardJonesModel,
    lj_model_b: LennardJonesModel,
    morse_model: MorseModel,
    si_sim_state: ts.SimState,
) -> None:
    """SumModel supports more than two child models."""
    children = (lj_model_a, lj_model_b, morse_model)
    triple = SumModel(*children)
    expected = sum(child(si_sim_state)["energy"] for child in children)
    torch.testing.assert_close(triple(si_sim_state)["energy"], expected)


def test_sum_model_compute_stress_setter(
    lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
    """compute_stress can be toggled off after construction."""
    composite = SumModel(lj_model_a, morse_model)
    assert composite.compute_stress is True
    composite.compute_stress = False
    assert composite.compute_stress is False


def test_sum_model_compute_forces_setter(
    lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
    """compute_forces can be toggled off after construction.

    Mirrors test_sum_model_compute_stress_setter: the initial state is
    asserted before toggling, so a regression that leaves the flag False
    from the start cannot slip through.
    """
    sm = SumModel(lj_model_a, morse_model)
    assert sm.compute_forces is True  # both fixtures enable forces
    sm.compute_forces = False
    assert sm.compute_forces is False


def test_sum_model_memory_scales_with(
    lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
    """The composite reports the most conservative child memory scaling."""
    composite = SumModel(lj_model_a, morse_model)
    assert composite.memory_scales_with == "n_atoms_x_density"


def test_sum_model_force_conservation(
    sum_model: SumModel, si_double_sim_state: ts.SimState
) -> None:
    """Summed forces on each system vanish (Newton's third law).

    Checks every system in the batched state independently via its
    system_idx mask.
    """
    results = sum_model(si_double_sim_state)
    for sys_idx in range(si_double_sim_state.n_systems):
        mask = si_double_sim_state.system_idx == sys_idx
        net_force = results["forces"][mask].sum(dim=0)
        # The zeros tensor must live on the same device as the model output;
        # otherwise torch.allclose raises a device-mismatch error on GPU runs.
        assert torch.allclose(
            net_force,
            torch.zeros(3, dtype=DTYPE, device=DEVICE),
            atol=1e-10,
        )


def test_sum_model_validate_outputs(sum_model: SumModel) -> None:
    """The composite satisfies the generic model-output contract."""
    validate_model_outputs(sum_model, DEVICE, DTYPE, check_detached=True)


def test_sum_model_retain_graph(
    lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
    """Setting retain_graph on the composite propagates to every child."""
    composite = SumModel(lj_model_a, morse_model)
    assert composite.retain_graph is False
    composite.retain_graph = True
    for model in (lj_model_a, morse_model, composite):
        assert model.retain_graph is True
139 changes: 137 additions & 2 deletions torch_sim/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,29 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs):
compute_stress property, as some integrators require stress calculations.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch

import torch_sim as ts
from torch_sim.state import SimState
from torch_sim.typing import MemoryScaling


if TYPE_CHECKING:
from torch_sim.state import SimState
from torch_sim.typing import MemoryScaling


# Absolute tolerance for tensor comparisons during model-output validation
# (presumably consumed by validate_model_outputs — its body is not shown here).
VALIDATE_ATOL = 1e-4

# Ranks memory-scaling regimes from cheapest ("n_atoms") to most
# memory-hungry ("n_edges"). SumModel.memory_scales_with reports the child
# scaling with the highest rank, i.e. the most conservative estimate.
_MEMORY_SCALING_PRIORITY: dict[MemoryScaling, int] = {
    "n_atoms": 0,
    "n_atoms_x_density": 1,
    "n_edges": 2,
}


class ModelInterface(torch.nn.Module, ABC):
"""Abstract base class for all simulation models in TorchSim.
Expand Down Expand Up @@ -171,6 +183,129 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
"""


class SumModel(ModelInterface):
    """Additive composition of multiple :class:`ModelInterface` models.

    Calls each child model's :meth:`forward` and sums the output tensors
    key-by-key, so energies, forces, and stresses are combined additively.
    This is the standard way to layer a dispersion correction (e.g. DFT-D3),
    an Ewald electrostatic term, or a local pair potential on top of a primary
    machine-learning potential.

    Args:
        models: Two or more :class:`ModelInterface` instances that share the
            same ``device`` and ``dtype``.

    Raises:
        ValueError: If fewer than two models are given or if ``device``/``dtype``
            do not match across all models.

    Examples:
        ```py
        sum_model = SumModel(mace_model, d3_model)
        output = sum_model(sim_state)
        ```
    """

    def __init__(self, *models: ModelInterface) -> None:
        """Initialize the sum model.

        Args:
            models: Two or more :class:`ModelInterface` instances. All must
                share the same ``device`` and ``dtype``.
        """
        super().__init__()
        if len(models) < 2:
            raise ValueError("SumModel requires at least two child models")
        first = models[0]
        for i, m in enumerate(models[1:], start=1):
            if m.device != first.device:
                raise ValueError(
                    f"Device mismatch: model 0 has {first.device}, "
                    f"model {i} has {m.device}"
                )
            if m.dtype != first.dtype:
                raise ValueError(
                    f"Dtype mismatch: model 0 has {first.dtype}, model {i} has {m.dtype}"
                )
        # ModuleList registers the children so .to()/state_dict() see them.
        self.models = torch.nn.ModuleList(models)
        self._device = first.device
        self._dtype = first.dtype
        # The composite only reports a quantity when every child computes it.
        self._compute_stress = all(m.compute_stress for m in models)
        self._compute_forces = all(m.compute_forces for m in models)

    def _children(self) -> list[ModelInterface]:
        """Return child models with proper typing for static analysis."""
        return list(self.models.children())  # type: ignore[return-value]

    @ModelInterface.compute_stress.setter
    def compute_stress(self, value: bool) -> None:  # noqa: FBT001
        """Propagate ``compute_stress`` to all child models that support it."""
        for m in self._children():
            try:
                m.compute_stress = value
            except NotImplementedError:
                # A child that cannot toggle stress is only fatal when stress
                # was requested; disabling stress is always safe.
                if value:
                    raise
        self._compute_stress = value

    @ModelInterface.compute_forces.setter
    def compute_forces(self, value: bool) -> None:  # noqa: FBT001
        """Propagate ``compute_forces`` to all child models that support it."""
        for m in self._children():
            try:
                m.compute_forces = value
            except NotImplementedError:
                if value:
                    raise
        self._compute_forces = value

    @property
    def retain_graph(self) -> bool:
        """Whether any child model retains the computation graph."""
        return any(getattr(m, "retain_graph", False) for m in self._children())

    @retain_graph.setter
    def retain_graph(self, value: bool) -> None:
        """Propagate ``retain_graph`` to every child model that exposes it."""
        for m in self._children():
            if hasattr(m, "retain_graph"):
                m.retain_graph = value  # type: ignore[union-attr]

    @property
    def memory_scales_with(self) -> MemoryScaling:
        """Most conservative memory-scaling among all child models."""
        best: MemoryScaling = "n_atoms"
        for m in self._children():
            scaling = m.memory_scales_with
            if _MEMORY_SCALING_PRIORITY[scaling] > _MEMORY_SCALING_PRIORITY[best]:
                best = scaling
        return best

    def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
        """Sum the outputs of all child models.

        Each child model is called with the same ``state`` and ``**kwargs``.
        Output tensors that appear in multiple children are summed element-wise;
        keys unique to a single child are passed through unchanged.

        Args:
            state: Simulation state (see :class:`ModelInterface`).
            **kwargs: Forwarded to every child model.

        Returns:
            Combined output dictionary with summed tensors.
        """
        combined: dict[str, torch.Tensor] = {}
        for model in self._children():
            output = model(state, **kwargs)
            for key, tensor in output.items():
                if key in combined:
                    combined[key] = combined[key] + tensor
                else:
                    combined[key] = tensor
        return combined


def _check_output_detached(
output: dict[str, torch.Tensor], model: ModelInterface
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions torch_sim/neighbors/vesin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# Optional dependency: vesin provides neighbor-list implementations.
# Fall back to None so callers can feature-detect via the flags below.
try:
    from vesin import NeighborList as VesinNeighborList
except ImportError:
    VesinNeighborList = None

try:
    from vesin.torch import NeighborList as VesinNeighborListTorch
except ImportError:
    VesinNeighborListTorch = None

# Availability flags derived from the import attempts above.
VESIN_AVAILABLE = VesinNeighborList is not None
VESIN_TORCHSCRIPT_AVAILABLE = VesinNeighborListTorch is not None
Expand Down
Loading