diff --git a/tests/models/test_sum_model.py b/tests/models/test_sum_model.py new file mode 100644 index 00000000..63d51201 --- /dev/null +++ b/tests/models/test_sum_model.py @@ -0,0 +1,203 @@ +"""Tests for the SumModel composite model.""" + +import pytest +import torch + +import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from torch_sim.models.interface import SumModel, validate_model_outputs +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.models.morse import MorseModel + + +@pytest.fixture +def lj_model_a() -> LennardJonesModel: + return LennardJonesModel( + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def lj_model_b() -> LennardJonesModel: + return LennardJonesModel( + sigma=2.0, + epsilon=0.005, + cutoff=5.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def morse_model() -> MorseModel: + return MorseModel( + sigma=2.55, + epsilon=0.436, + alpha=1.359, + cutoff=6.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumModel: + return SumModel(lj_model_a, morse_model) + + +def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None: + with pytest.raises(ValueError, match="at least two"): + SumModel(lj_model_a) + + +def test_sum_model_device_mismatch() -> None: + m1 = LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, device=torch.device("cpu")) + m2 = LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, device=torch.device("cpu")) + object.__setattr__(m2, "_device", torch.device("meta")) + with pytest.raises(ValueError, match="Device mismatch"): + SumModel(m1, m2) + + +def test_sum_model_dtype_mismatch() -> None: + m1 = LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, dtype=torch.float64) + m2 = 
LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, dtype=torch.float32) + with pytest.raises(ValueError, match="Dtype mismatch"): + SumModel(m1, m2) + + +def test_sum_model_properties(sum_model: SumModel) -> None: + assert sum_model.device == DEVICE + assert sum_model.dtype == DTYPE + assert sum_model.compute_stress is True + assert sum_model.compute_forces is True + + +def test_sum_model_energy_summation( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_sim_state) + morse_out = morse_model(si_sim_state) + sum_out = sum_model(si_sim_state) + expected_energy = lj_out["energy"] + morse_out["energy"] + torch.testing.assert_close(sum_out["energy"], expected_energy) + + +def test_sum_model_forces_summation( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_sim_state) + morse_out = morse_model(si_sim_state) + sum_out = sum_model(si_sim_state) + expected_forces = lj_out["forces"] + morse_out["forces"] + torch.testing.assert_close(sum_out["forces"], expected_forces) + + +def test_sum_model_stress_summation( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_sim_state) + morse_out = morse_model(si_sim_state) + sum_out = sum_model(si_sim_state) + expected_stress = lj_out["stress"] + morse_out["stress"] + torch.testing.assert_close(sum_out["stress"], expected_stress) + + +def test_sum_model_batched( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_double_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_double_sim_state) + morse_out = morse_model(si_double_sim_state) + sum_out = sum_model(si_double_sim_state) + torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + morse_out["energy"]) + 
torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + morse_out["forces"]) + torch.testing.assert_close(sum_out["stress"], lj_out["stress"] + morse_out["stress"]) + + +def test_sum_model_three_models( + lj_model_a: LennardJonesModel, + lj_model_b: LennardJonesModel, + morse_model: MorseModel, + si_sim_state: ts.SimState, +) -> None: + triple = SumModel(lj_model_a, lj_model_b, morse_model) + a_out = lj_model_a(si_sim_state) + b_out = lj_model_b(si_sim_state) + c_out = morse_model(si_sim_state) + sum_out = triple(si_sim_state) + torch.testing.assert_close( + sum_out["energy"], a_out["energy"] + b_out["energy"] + c_out["energy"] + ) + + +def test_sum_model_compute_stress_setter( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + assert sm.compute_stress is True + sm.compute_stress = False + assert sm.compute_stress is False + + +def test_sum_model_compute_forces_setter( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + sm.compute_forces = False + assert sm.compute_forces is False + + +def test_sum_model_memory_scales_with( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + assert sm.memory_scales_with == "n_atoms_x_density" + + +def test_sum_model_force_conservation( + sum_model: SumModel, si_double_sim_state: ts.SimState +) -> None: + results = sum_model(si_double_sim_state) + for sys_idx in range(si_double_sim_state.n_systems): + mask = si_double_sim_state.system_idx == sys_idx + assert torch.allclose( + results["forces"][mask].sum(dim=0), + torch.zeros(3, device=DEVICE, dtype=DTYPE), + atol=1e-10, + ) + + +def test_sum_model_validate_outputs(sum_model: SumModel) -> None: + validate_model_outputs(sum_model, DEVICE, DTYPE, check_detached=True) + + +def test_sum_model_retain_graph( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + 
assert sm.retain_graph is False + sm.retain_graph = True + assert lj_model_a.retain_graph is True + assert morse_model.retain_graph is True + assert sm.retain_graph is True diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 8aa6bb5e..885627b1 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -26,17 +26,29 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): compute_stress property, as some integrators require stress calculations. """ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import torch import torch_sim as ts -from torch_sim.state import SimState -from torch_sim.typing import MemoryScaling + + +if TYPE_CHECKING: + from torch_sim.state import SimState + from torch_sim.typing import MemoryScaling VALIDATE_ATOL = 1e-4 +_MEMORY_SCALING_PRIORITY: dict[MemoryScaling, int] = { + "n_atoms": 0, + "n_atoms_x_density": 1, + "n_edges": 2, +} + class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in TorchSim. @@ -171,6 +183,129 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """ +class SumModel(ModelInterface): + """Additive composition of multiple :class:`ModelInterface` models. + + Calls each child model's :meth:`forward` and sums the output tensors + key-by-key, so energies, forces, and stresses are combined additively. + This is the standard way to layer a dispersion correction (e.g. DFT-D3), + an Ewald electrostatic term, or a local pair potential on top of a primary + machine-learning potential. + + Args: + models: Two or more :class:`ModelInterface` instances that share the + same ``device`` and ``dtype``. + + Raises: + ValueError: If fewer than two models are given or if ``device``/``dtype`` + do not match across all models. 
+ + Examples: + ```py + sum_model = SumModel(mace_model, d3_model) + output = sum_model(sim_state) + ``` + """ + + def __init__(self, *models: ModelInterface) -> None: + """Initialize the sum model. + + Args: + models: Two or more :class:`ModelInterface` instances. All must + share the same ``device`` and ``dtype``. + """ + super().__init__() + if len(models) < 2: + raise ValueError("SumModel requires at least two child models") + first = models[0] + for i, m in enumerate(models[1:], start=1): + if m.device != first.device: + raise ValueError( + f"Device mismatch: model 0 has {first.device}, " + f"model {i} has {m.device}" + ) + if m.dtype != first.dtype: + raise ValueError( + f"Dtype mismatch: model 0 has {first.dtype}, model {i} has {m.dtype}" + ) + self.models = torch.nn.ModuleList(models) + self._device = first.device + self._dtype = first.dtype + self._compute_stress = all(m.compute_stress for m in models) + self._compute_forces = all(m.compute_forces for m in models) + + def _children(self) -> list[ModelInterface]: + """Return child models with proper typing for static analysis.""" + return list(self.models.children()) # type: ignore[return-value] + + @ModelInterface.compute_stress.setter + def compute_stress(self, value: bool) -> None: # noqa: FBT001 + """Propagate ``compute_stress`` to all child models that support it.""" + for m in self._children(): + try: + m.compute_stress = value + except NotImplementedError: + if value: + raise + self._compute_stress = value + + @ModelInterface.compute_forces.setter + def compute_forces(self, value: bool) -> None: # noqa: FBT001 + """Propagate ``compute_forces`` to all child models that support it.""" + for m in self._children(): + try: + m.compute_forces = value + except NotImplementedError: + if value: + raise + self._compute_forces = value + + @property + def retain_graph(self) -> bool: + """Whether any child model retains the computation graph.""" + return any(getattr(m, "retain_graph", False) for m in 
self._children()) + + @retain_graph.setter + def retain_graph(self, value: bool) -> None: + for m in self._children(): + if hasattr(m, "retain_graph"): + m.retain_graph = value # type: ignore[union-attr] + + @property + def memory_scales_with(self) -> MemoryScaling: + """Most conservative memory-scaling among all child models.""" + best: MemoryScaling = "n_atoms" + for m in self._children(): + scaling = m.memory_scales_with + if _MEMORY_SCALING_PRIORITY[scaling] > _MEMORY_SCALING_PRIORITY[best]: + best = scaling + return best + + def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: + """Sum the outputs of all child models. + + Each child model is called with the same ``state`` and ``**kwargs``. + Output tensors that appear in multiple children are summed element-wise; + keys unique to a single child are passed through unchanged. + + Args: + state: Simulation state (see :class:`ModelInterface`). + **kwargs: Forwarded to every child model. + + Returns: + Combined output dictionary with summed tensors. 
+ """ + combined: dict[str, torch.Tensor] = {} + for model in self._children(): + output = model(state, **kwargs) + for key, tensor in output.items(): + if key in combined: + combined[key] = combined[key] + tensor + else: + combined[key] = tensor + return combined + + def _check_output_detached( output: dict[str, torch.Tensor], model: ModelInterface ) -> None: diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index 009fe9bb..16648950 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -12,13 +12,13 @@ try: from vesin import NeighborList as VesinNeighborList except ImportError: - VesinNeighborList = None # type: ignore[assignment] + VesinNeighborList = None try: from vesin.torch import NeighborList as VesinNeighborListTorch except ImportError: - VesinNeighborListTorch = None # ty:ignore[invalid-assignment] + VesinNeighborListTorch = None VESIN_AVAILABLE = VesinNeighborList is not None VESIN_TORCHSCRIPT_AVAILABLE = VesinNeighborListTorch is not None