diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 4932994d9..a37d2d667 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -1,7 +1,7 @@ .. _engine-torch-sim: -torch-sim -========= +TorchSim +======== .. list-table:: :header-rows: 1 @@ -25,8 +25,23 @@ For the full TorchSim documentation, see https://torchsim.github.io/torch-sim/. Supported model outputs ^^^^^^^^^^^^^^^^^^^^^^^ -Only the :ref:`energy <energy-output>` output is supported. Forces and stresses -are derived via autograd. +The :ref:`energy <energy-output>` output is the primary output. Forces and +stresses are derived via autograd by default. The wrapper also supports: + +- **Non-conservative forces/stress**: use direct prediction of gradients instead + of autograd (``non_conservative=True``) +- **Energy uncertainty**: per-atom uncertainty warnings when the model provides + an ``energy_uncertainty`` output +- **Additional outputs**: request arbitrary extra model outputs via + ``additional_outputs``; results are stored as + :py:class:`metatensor.torch.TensorMap` in the + :py:attr:`~metatomic_torchsim.MetatomicModel.additional_outputs` attribute + +See the :py:class:`~metatomic_torchsim.MetatomicModel` API documentation below +for details on all parameters, and the tutorials for worked examples: + +- :ref:`torchsim-getting-started` -- loading a model and running NVE dynamics +- :ref:`torchsim-batched` -- evaluating multiple systems in a single call How to use the code ^^^^^^^^^^^^^^^^^^^ diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py new file mode 100644 index 000000000..f39a5f10e --- /dev/null +++ b/python/examples/5-torchsim-getting-started.py @@ -0,0 +1,192 @@ +""" +.. _torchsim-getting-started: + +Getting started with TorchSim +============================= + +This tutorial walks through running a short NVE molecular dynamics +simulation with a metatomic model and `TorchSim +<https://torchsim.github.io/torch-sim/>`_. 
+""" + +# %% +# +# Prerequisites +# ------------- +# +# Install the integration package and its dependencies: +# +# .. code-block:: bash +# +# pip install metatomic-torchsim +# +# We start by importing the modules we need: + +from typing import Dict, List, Optional + +import ase.build +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +import metatomic.torch as mta +from metatomic_torchsim import MetatomicModel + + +# %% +# +# Export a simple model +# --------------------- +# +# For this tutorial we create and export a minimal model that predicts +# energy as a (trivial) function of atomic positions. The energy must +# depend on positions so that forces can be computed via autograd. +# In practice you would use a pre-trained model loaded from a file. + + +class HarmonicEnergy(torch.nn.Module): + """A minimal model: harmonic restraint around initial positions.""" + + def __init__(self, k: float = 0.1): + super().__init__() + self.k = k + + def forward( + self, + systems: List[mta.System], + outputs: Dict[str, mta.ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + energies: List[torch.Tensor] = [] + for system in systems: + # energy = k * sum(positions^2) -- differentiable w.r.t. 
positions + e = self.k * torch.sum(system.positions**2) + energies.append(e.reshape(1, 1)) + + energy = torch.cat(energies, dim=0) + + block = TensorBlock( + values=energy, + samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + return { + "energy": TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block]) + } + + +# %% +# +# Build an ``AtomisticModel`` wrapping the raw module: + +raw_model = HarmonicEnergy(k=0.1) +capabilities = mta.ModelCapabilities( + length_unit="Angstrom", + atomic_types=[14], # Silicon + interaction_range=0.0, + outputs={"energy": mta.ModelOutput(quantity="energy", unit="eV")}, + supported_devices=["cpu"], + dtype="float64", +) + +atomistic_model = mta.AtomisticModel( + raw_model.eval(), mta.ModelMetadata(), capabilities +) + +# %% +# +# Load the model +# -------------- +# +# Wrap the model with :py:class:`~metatomic_torchsim.MetatomicModel`. +# You can pass an ``AtomisticModel`` directly, or a path to a saved +# ``.pt`` file: + +model = MetatomicModel(atomistic_model, device="cpu") + +# %% +# +# The wrapper detects the model's dtype and supported devices +# automatically. Pass ``device="cuda"`` to run on GPU when available. + +print("dtype:", model.dtype) +print("device:", model.device) + +# %% +# +# Build a simulation state +# ------------------------ +# +# TorchSim works with ``SimState`` objects. 
Convert ASE ``Atoms`` using +# ``torch_sim.initialize_state``: + +import torch_sim as ts # noqa: E402 + + +atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) +sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + +print("Number of atoms:", sim_state.n_atoms) + +# %% +# +# Evaluate the model +# ------------------ +# +# Call the model on the simulation state to get energies, forces, and +# stresses: + +results = model(sim_state) + +print("Energy:", results["energy"]) # shape [1] +print("Forces shape:", results["forces"].shape) # shape [n_atoms, 3] +print("Stress shape:", results["stress"].shape) # shape [1, 3, 3] + +# %% +# +# Run NVE dynamics +# ---------------- +# +# Use TorchSim's NVE (Velocity Verlet) integrator to run a short trajectory. +# ``nve_init`` samples momenta from a Maxwell-Boltzmann distribution at the +# given temperature, and ``nve_step`` advances by one timestep: + +import matplotlib.pyplot as plt # noqa: E402 +from torch_sim.integrators import nve_init, nve_step # noqa: E402 +from torch_sim.units import MetalUnits # noqa: E402 + + +sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + +# Initialize NVE state with momenta at 300 K (in eV units) +kT = 300.0 * MetalUnits.temperature # kelvin -> eV +md_state = nve_init(sim_state, model, kT=kT) + +energies = [] +steps = [] +dt = 1.0 # femtoseconds + +for step in range(50): + md_state = nve_step(md_state, model, dt=dt) + energies.append(md_state.energy.sum().item()) + steps.append(step) + +plt.plot(steps, energies) +plt.xlabel("Step") +plt.ylabel("Potential energy (eV)") +plt.title("NVE dynamics -- potential energy vs step") +plt.tight_layout() +plt.show() + + +# %% +# +# .. note:: +# +# With a real interatomic potential the total energy would stay approximately +# constant in an NVE simulation, which serves as a basic sanity check. 
+# +# Next steps +# ---------- +# +# - :ref:`torchsim-batched` explains running multiple systems at once diff --git a/python/examples/6-torchsim-batched.py b/python/examples/6-torchsim-batched.py new file mode 100644 index 000000000..6a2e4730d --- /dev/null +++ b/python/examples/6-torchsim-batched.py @@ -0,0 +1,176 @@ +""" +.. _torchsim-batched: + +Batched simulations with TorchSim +================================= + +TorchSim supports batching multiple systems into a single ``SimState`` +for efficient parallel evaluation on GPU. +:py:class:`~metatomic_torchsim.MetatomicModel` handles this +transparently. +""" + +# %% +# +# Setup +# ----- +# +# We reuse the same minimal model from :ref:`torchsim-getting-started`. +# The model must produce differentiable energy so that forces/stress can +# be computed via autograd. + +from typing import Dict, List, Optional + +import ase.build +import matplotlib.pyplot as plt +import torch +import torch_sim as ts +from metatensor.torch import Labels, TensorBlock, TensorMap + +import metatomic.torch as mta +from metatomic_torchsim import MetatomicModel + + +class HarmonicEnergy(torch.nn.Module): + """Harmonic restraint: E = k * sum(positions^2).""" + + def __init__(self, k: float = 0.1): + super().__init__() + self.k = k + + def forward( + self, + systems: List[mta.System], + outputs: Dict[str, mta.ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + energies: List[torch.Tensor] = [] + for system in systems: + e = self.k * torch.sum(system.positions**2) + energies.append(e.reshape(1, 1)) + + energy = torch.cat(energies, dim=0) + block = TensorBlock( + values=energy, + samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + return { + "energy": TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block]) + } + + +capabilities = mta.ModelCapabilities( + length_unit="Angstrom", + atomic_types=[13, 29], # Al, Cu 
+ interaction_range=0.0, + outputs={"energy": mta.ModelOutput(quantity="energy", unit="eV")}, + supported_devices=["cpu"], + dtype="float64", +) + +atomistic_model = mta.AtomisticModel( + HarmonicEnergy(0.1).eval(), mta.ModelMetadata(), capabilities +) + +model = MetatomicModel(atomistic_model, device="cpu") + + +# %% +# +# Creating a batched state +# ------------------------ +# +# Pass a list of ASE ``Atoms`` objects to ``initialize_state``: + +atoms_list = [ + ase.build.bulk("Cu", "fcc", a=3.6, cubic=True), + ase.build.bulk("Cu", "fcc", a=3.65, cubic=True), + ase.build.bulk("Al", "fcc", a=4.05, cubic=True), +] + +sim_state = ts.initialize_state(atoms_list, device=model.device, dtype=model.dtype) +print("Total atoms in batch:", sim_state.n_atoms) + +# %% +# +# Evaluating the batch +# -------------------- +# +# A single forward call evaluates all systems: + +results = model(sim_state) + +print("Energy shape:", results["energy"].shape) # [n_systems] +print("Forces shape:", results["forces"].shape) # [n_total_atoms, 3] +print("Stress shape:", results["stress"].shape) # [n_systems, 3, 3] + +# %% +# +# The output shapes reflect the batch: +# +# - ``results["energy"]`` has shape ``[n_systems]`` -- one energy per system +# - ``results["forces"]`` has shape ``[n_total_atoms, 3]`` -- all atoms +# concatenated +# - ``results["stress"]`` has shape ``[n_systems, 3, 3]`` -- one 3x3 tensor +# per system + +print("Per-system energies:", results["energy"]) + +# %% +# +# How ``system_idx`` works +# ------------------------ +# +# ``SimState`` tracks which atom belongs to which system via the +# ``system_idx`` tensor. For three 4-atom systems, ``system_idx`` looks +# like: + +print("system_idx:", sim_state.system_idx) + +# %% +# +# ``MetatomicModel.forward`` uses this to split the batched positions and +# types into per-system ``System`` objects before calling the underlying +# model. 
+# +# Batch consistency +# ----------------- +# +# Energies computed in a batch match those computed individually. +# This is guaranteed because each system gets its own neighbor list and +# independent evaluation: + +individual_energies = [] +for atoms in atoms_list: + state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + res = model(state) + individual_energies.append(res["energy"].item()) + +print("Batched: ", [e.item() for e in results["energy"]]) +print("Individual:", individual_energies) + +plt.scatter(individual_energies, results["energy"].cpu().numpy()) +plt.plot( + [min(individual_energies), max(individual_energies)], + [min(individual_energies), max(individual_energies)], + "k--", +) +plt.xlabel("Individual energies") +plt.ylabel("Batched energies") +plt.show() + +# %% +# +# Performance considerations +# -------------------------- +# +# Batching is most beneficial on GPU, where the neighbor list computation +# and model forward pass can run in parallel across systems. On CPU, the +# speedup comes from reduced Python overhead (one call instead of N). +# +# For very large systems or many small ones, adjust the batch size to fit +# in GPU memory. TorchSim does not impose a maximum batch size, but each +# system gets its own neighbor list, so memory scales with the sum of +# per-system sizes. diff --git a/python/metatomic_torchsim/CHANGELOG.md b/python/metatomic_torchsim/CHANGELOG.md index d74929e06..027e53b5e 100644 --- a/python/metatomic_torchsim/CHANGELOG.md +++ b/python/metatomic_torchsim/CHANGELOG.md @@ -18,3 +18,13 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - `metatomic-torchsim` is now a standalone package, containing the TorchSim integration for metatomic models. 
+ +### Added + +- Support for output variants via the `variants` parameter, matching the ASE + calculator's variant selection +- Non-conservative forces and stresses via `non_conservative=True`, reading + model outputs directly instead of autograd +- Per-atom energy uncertainty warnings via `uncertainty_threshold`, triggered + when the model provides `energy_uncertainty` with `per_atom=True` +- `additional_outputs` parameter for requesting arbitrary extra model outputs diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 6612c4612..428ff6b8f 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -4,17 +4,20 @@ be used within the torch-sim simulation framework for MD and other simulations. Supports batched computations for multiple systems simultaneously, computing -energies, forces, and stresses via autograd. +energies, forces, and stresses via autograd. Also supports output variants, +non-conservative forces/stress, energy uncertainty warnings, and additional +model outputs. """ import logging import os import pathlib +import warnings from typing import Dict, List, Optional, Union import torch import vesin.metatomic -from metatensor.torch import Labels, TensorBlock +from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( AtomisticModel, @@ -24,6 +27,7 @@ System, load_atomistic_model, pick_device, + pick_output, ) @@ -76,6 +80,10 @@ def __init__( check_consistency: bool = False, compute_forces: bool = True, compute_stress: bool = True, + variants: Optional[Dict[str, Optional[str]]] = None, + non_conservative: bool = False, + uncertainty_threshold: Optional[float] = 0.1, + additional_outputs: Optional[Dict[str, ModelOutput]] = None, ) -> None: """ :param model: Model to use. 
Accepts a file path to a ``.pt`` saved @@ -89,6 +97,23 @@ def __init__( Useful for debugging but hurts performance. :param compute_forces: Compute atomic forces via autograd. :param compute_stress: Compute stress tensors via the strain trick. + :param variants: Dictionary mapping output names to a variant that should + be used. Setting ``{"energy": "pbe"}`` selects the ``"energy/pbe"`` + output. The energy variant propagates to uncertainty and + non-conservative outputs unless overridden (e.g. + ``{"energy": "pbe", "energy_uncertainty": "r2scan"}`` would select + ``energy/pbe`` and ``energy_uncertainty/r2scan``). + :param non_conservative: If ``True``, the model will be asked to compute + non-conservative forces and stresses. This can afford a speed-up, + potentially at the expense of physical correctness (especially in + molecular dynamics simulations). + :param uncertainty_threshold: Threshold for per-atom energy uncertainty + in eV. When the model supports ``energy_uncertainty`` with + ``per_atom=True``, atoms exceeding this threshold trigger a warning. + Set to ``None`` to disable. + :param additional_outputs: Dictionary of extra :py:class:`ModelOutput` + to request from the model. Results are stored in + :py:attr:`additional_outputs` after each forward call. """ super().__init__() @@ -133,25 +158,138 @@ def __init__( f"unexpected dtype in model capabilities: {capabilities.dtype}" ) - if "energy" not in capabilities.outputs: + # Resolve output keys based on requested variants + variants = variants or {} + default_variant = variants.get("energy") + + resolved_variants = { + key: variants.get(key, default_variant) + for key in [ + "energy", + "energy_uncertainty", + "non_conservative_forces", + "non_conservative_stress", + ] + } + + outputs = capabilities.outputs + + has_energy = any( + "energy" == key or key.startswith("energy/") for key in outputs.keys() + ) + if not has_energy: raise ValueError( "model does not have an 'energy' output. 
" "Only models with energy outputs can be used with TorchSim." ) + self._energy_key = pick_output("energy", outputs, resolved_variants["energy"]) + + # Uncertainty + has_energy_uq = any("energy_uncertainty" in key for key in outputs.keys()) + if has_energy_uq and uncertainty_threshold is not None: + self._energy_uq_key = pick_output( + "energy_uncertainty", + outputs, + resolved_variants["energy_uncertainty"], + ) + else: + self._energy_uq_key = "energy_uncertainty" + + # Non-conservative outputs + self._non_conservative = non_conservative + if non_conservative: + if ( + "non_conservative_stress" in variants + and "non_conservative_forces" in variants + and ( + (variants["non_conservative_stress"] is None) + != (variants["non_conservative_forces"] is None) + ) + ): + raise ValueError( + "if both 'non_conservative_stress' and " + "'non_conservative_forces' are present in `variants`, they " + "must either be both `None` or both not `None`." + ) + + self._nc_forces_key = pick_output( + "non_conservative_forces", + outputs, + resolved_variants["non_conservative_forces"], + ) + self._nc_stress_key = pick_output( + "non_conservative_stress", + outputs, + resolved_variants["non_conservative_stress"], + ) + else: + self._nc_forces_key = "non_conservative_forces" + self._nc_stress_key = "non_conservative_stress" + + # Additional outputs + if additional_outputs is None: + self._additional_output_requests: Dict[str, ModelOutput] = {} + else: + assert isinstance(additional_outputs, dict) + for name, output in additional_outputs.items(): + assert isinstance(name, str) + assert isinstance(output, torch.ScriptObject), ( + "outputs must be ModelOutput instances" + ) + self._additional_output_requests = additional_outputs + self._model = model.to(device=self._device) self._compute_forces = compute_forces self._compute_stress = compute_stress + self._uncertainty_threshold = uncertainty_threshold + + self._calculate_uncertainty = ( + self._energy_uq_key in 
self._model.capabilities().outputs + and self._model.capabilities().outputs[self._energy_uq_key].per_atom + and uncertainty_threshold is not None + ) + + if self._calculate_uncertainty: + if uncertainty_threshold <= 0.0: + raise ValueError( + f"`uncertainty_threshold` is {uncertainty_threshold} but must " + "be positive" + ) self._requested_neighbor_lists = self._model.requested_neighbor_lists() + # Precompute the outputs dict (immutable after __init__) + run_outputs: Dict[str, ModelOutput] = { + self._energy_key: ModelOutput(quantity="energy", unit="eV", per_atom=False), + } + if self._calculate_uncertainty: + run_outputs[self._energy_uq_key] = ModelOutput( + quantity="energy", unit="eV", per_atom=True + ) + if self._non_conservative: + if self._compute_forces: + run_outputs[self._nc_forces_key] = ModelOutput( + quantity="force", unit="eV/Angstrom", per_atom=True + ) + if self._compute_stress: + run_outputs[self._nc_stress_key] = ModelOutput( + quantity="pressure", unit="eV/Angstrom^3", per_atom=False + ) + run_outputs.update(self._additional_output_requests) + self._evaluation_options = ModelEvaluationOptions( length_unit="angstrom", - outputs={ - "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False) - }, + outputs=run_outputs, ) + self.additional_outputs: Dict[str, TensorMap] = {} + """ + Additional outputs computed by :py:meth:`forward` are stored here. + Keys match the ``additional_outputs`` parameter to the constructor; + values are raw :py:class:`metatensor.torch.TensorMap` from the model. + """ + def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: """Compute energies, forces, and stresses for the given simulation state. 
@@ -171,22 +309,32 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: f"model dtype {self._dtype}" ) + # Determine whether autograd is needed + do_autograd_forces = self._compute_forces and not self._non_conservative + do_autograd_stress = self._compute_stress and not self._non_conservative + # Build per-system System objects. Metatomic expects a list of System # rather than a single batched graph. systems: List[System] = [] strains: List[torch.Tensor] = [] n_systems = len(cell) + pbc = state.pbc + if isinstance(pbc, bool): + pbc = torch.tensor([pbc, pbc, pbc]) + elif not isinstance(pbc, torch.Tensor): + pbc = torch.tensor(pbc) + for sys_idx in range(n_systems): mask = state.system_idx == sys_idx sys_positions = positions[mask] sys_cell = cell[sys_idx] sys_types = atomic_nums[mask] - if self._compute_forces: + if do_autograd_forces: sys_positions = sys_positions.detach().requires_grad_(True) - if self._compute_stress: + if do_autograd_stress: strain = torch.eye( 3, device=self._device, @@ -202,7 +350,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: positions=sys_positions, types=sys_types, cell=sys_cell, - pbc=state.pbc, + pbc=pbc, ) ) @@ -213,25 +361,66 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: check_consistency=self._check_consistency, ) - # Run the model + # Run the model (evaluation options precomputed in __init__) model_outputs = self._model( systems=systems, options=self._evaluation_options, check_consistency=self._check_consistency, ) - energy_values = model_outputs["energy"].block().values + energy_values = model_outputs[self._energy_key].block().values results: Dict[str, torch.Tensor] = {} results["energy"] = energy_values.detach().squeeze(-1) - # Compute forces and/or stresses via autograd - if self._compute_forces or self._compute_stress: - grad_inputs: List[torch.Tensor] = [] + # Uncertainty warning + if self._calculate_uncertainty: + uncertainty = 
model_outputs[self._energy_uq_key].block().values + n_total_atoms = positions.shape[0] + if uncertainty.shape != (n_total_atoms, 1): + raise ValueError( + f"expected uncertainty shape ({n_total_atoms}, 1), " + f"got {uncertainty.shape}" + ) + threshold = self._uncertainty_threshold + if torch.any(uncertainty > threshold): + exceeded = torch.where(uncertainty.squeeze(-1) > threshold)[0] + atom_list = exceeded.tolist() + if len(atom_list) > 20: + atom_list = atom_list[:20] + suffix = f" (and {len(exceeded) - 20} more)" + else: + suffix = "" + warnings.warn( + "Some of the atomic energy uncertainties are larger than the " + f"threshold of {threshold} eV. The prediction is above the " + f"threshold for atoms {atom_list}{suffix}.", + stacklevel=2, + ) + + # Forces and stresses + if self._non_conservative: if self._compute_forces: + nc_forces = model_outputs[self._nc_forces_key].block().values.detach() + nc_forces = nc_forces.reshape(-1, 3) + # Remove spurious net force per system + for sys_idx in range(n_systems): + mask = state.system_idx == sys_idx + sys_forces = nc_forces[mask] + nc_forces[mask] = sys_forces - sys_forces.mean(dim=0, keepdim=True) + results["forces"] = nc_forces + + if self._compute_stress: + nc_stress = model_outputs[self._nc_stress_key].block().values.detach() + nc_stress = nc_stress.reshape(n_systems, 3, 3) + results["stress"] = nc_stress + + elif do_autograd_forces or do_autograd_stress: + grad_inputs: List[torch.Tensor] = [] + if do_autograd_forces: for system in systems: grad_inputs.append(system.positions) - if self._compute_stress: + if do_autograd_stress: grad_inputs.extend(strains) grads = torch.autograd.grad( @@ -240,21 +429,21 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: grad_outputs=torch.ones_like(energy_values), ) - if self._compute_forces and self._compute_stress: + if do_autograd_forces and do_autograd_stress: n_sys = len(systems) force_grads = grads[:n_sys] stress_grads = grads[n_sys:] - elif 
self._compute_forces: + elif do_autograd_forces: force_grads = grads stress_grads = () else: force_grads = () stress_grads = grads - if self._compute_forces: + if do_autograd_forces: results["forces"] = torch.cat([-g for g in force_grads]) - if self._compute_stress: + if do_autograd_stress: results["stress"] = torch.stack( [ g / torch.abs(torch.det(system.cell.detach())) @@ -262,6 +451,11 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: ] ) + # Store additional outputs + self.additional_outputs = {} + for name in self._additional_output_requests: + self.additional_outputs[name] = model_outputs[name] + return results diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 83428dd93..03f229fc9 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -1,7 +1,9 @@ """Tests for the MetatomicModel TorchSim wrapper. Uses the metatomic-lj-test model so that tests run without -downloading large model files. +downloading large model files. The pure-PyTorch LJ model +(``with_extension=False``) provides NC forces/stress, energy +uncertainty, and "/doubled" variants for full feature testing. 
""" import numpy as np @@ -10,6 +12,7 @@ import torch_sim as ts import metatomic_lj_test +from metatomic.torch import ModelOutput from metatomic_torchsim import MetatomicModel @@ -23,6 +26,7 @@ @pytest.fixture def lj_model(): + """Pure-PyTorch LJ model with NC, UQ, and variant outputs.""" return metatomic_lj_test.lennard_jones_model( atomic_type=28, cutoff=CUTOFF, @@ -34,6 +38,20 @@ def lj_model(): ) +@pytest.fixture +def lj_model_ext(): + """Extension LJ model (no NC/UQ outputs).""" + return metatomic_lj_test.lennard_jones_model( + atomic_type=28, + cutoff=CUTOFF, + sigma=SIGMA, + epsilon=EPSILON, + length_unit="Angstrom", + energy_unit="eV", + with_extension=True, + ) + + @pytest.fixture def ni_atoms(): """Create a small perturbed Ni FCC supercell.""" @@ -49,7 +67,7 @@ def ni_atoms(): @pytest.fixture def metatomic_model(lj_model): - return MetatomicModel(model=lj_model, device=DEVICE) + return MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) def test_initialization(lj_model): @@ -63,7 +81,12 @@ def test_initialization(lj_model): def test_initialization_no_forces(lj_model): """Can disable force computation.""" - model = MetatomicModel(model=lj_model, device=DEVICE, compute_forces=False) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + compute_forces=False, + uncertainty_threshold=None, + ) assert model.compute_forces is False assert model.compute_stress is True @@ -101,7 +124,9 @@ def test_forward_returns_stress(metatomic_model, ni_atoms): def test_forward_no_stress(lj_model, ni_atoms): """Stress is not returned when compute_stress=False.""" - model = MetatomicModel(model=lj_model, device=DEVICE, compute_stress=False) + model = MetatomicModel( + model=lj_model, device=DEVICE, compute_stress=False, uncertainty_threshold=None + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -112,7 +137,12 @@ def test_forward_no_stress(lj_model, ni_atoms): def test_forward_no_forces(lj_model, 
ni_atoms): """Forces are not returned when compute_forces=False.""" - model = MetatomicModel(model=lj_model, device=DEVICE, compute_forces=False) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + compute_forces=False, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -209,7 +239,7 @@ def test_single_atom_system(lj_model): cell=[10.0, 10.0, 10.0], pbc=True, ) - model = MetatomicModel(model=lj_model, device=DEVICE) + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) sim_state = ts.io.atoms_to_state([atoms], DEVICE, DTYPE) output = model(sim_state) @@ -221,7 +251,11 @@ def test_single_atom_system(lj_model): def test_energy_only_mode(lj_model, ni_atoms): """Model returns only energy when forces and stress are disabled.""" model = MetatomicModel( - model=lj_model, device=DEVICE, compute_forces=False, compute_stress=False + model=lj_model, + device=DEVICE, + compute_forces=False, + compute_stress=False, + uncertainty_threshold=None, ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -233,7 +267,12 @@ def test_energy_only_mode(lj_model, ni_atoms): def test_check_consistency_mode(lj_model, ni_atoms): """Model runs with consistency checking enabled.""" - model = MetatomicModel(model=lj_model, device=DEVICE, check_consistency=True) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + check_consistency=True, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -245,7 +284,9 @@ def test_check_consistency_mode(lj_model, ni_atoms): def test_forces_match_finite_difference(lj_model, ni_atoms): """Autograd forces match finite-difference gradient of energy.""" delta = 1e-4 - model = MetatomicModel(model=lj_model, device=DEVICE, compute_stress=False) + model = MetatomicModel( + model=lj_model, device=DEVICE, compute_stress=False, 
uncertainty_threshold=None + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) autograd_forces = output["forces"] @@ -279,3 +320,240 @@ def test_stress_is_symmetric(metatomic_model, ni_atoms): stress = output["stress"] torch.testing.assert_close(stress, stress.transpose(-2, -1), atol=1e-10, rtol=0) + + +# ---- Variants ---- + + +def test_variants_default(lj_model, ni_atoms): + """Default variant (None) selects the base energy output.""" + model = MetatomicModel( + model=lj_model, + device=DEVICE, + variants={"energy": None}, + uncertainty_threshold=None, + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "energy" in output + assert output["energy"].shape == (1,) + + +def test_variants_doubled(lj_model, ni_atoms): + """Selecting the 'doubled' variant gives 2x the base energy.""" + model_base = MetatomicModel( + model=lj_model, device=DEVICE, uncertainty_threshold=None + ) + model_doubled = MetatomicModel( + model=lj_model, + device=DEVICE, + variants={"energy": "doubled"}, + uncertainty_threshold=None, + ) + + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + e_base = model_base(sim_state)["energy"] + e_doubled = model_doubled(sim_state)["energy"] + + torch.testing.assert_close(e_doubled, 2.0 * e_base, atol=1e-10, rtol=0) + + +# ---- Uncertainty ---- + + +def test_uncertainty_warning_emitted(lj_model, ni_atoms): + """Uncertainty warning fires when atoms exceed threshold.""" + # LJ test model pseudo-uncertainty scales with system size. + # Use a very small threshold to guarantee it fires. + # filterwarnings = ["error"] converts warnings to exceptions, + # so we catch it as an error. 
+ model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=1e-10) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + with pytest.raises(UserWarning, match="uncertainties are larger"): + model(sim_state) + + +def test_uncertainty_no_warning_high_threshold(lj_model, ni_atoms): + """No warning when threshold is above all uncertainties.""" + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=1e6) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + # Should not warn -- high threshold above all uncertainty values + model(sim_state) + + +def test_uncertainty_threshold_none(lj_model, ni_atoms): + """Setting uncertainty_threshold=None disables UQ entirely.""" + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + # Should not warn -- UQ disabled + model(sim_state) + + +def test_negative_uncertainty_threshold_raises(lj_model): + """Negative uncertainty_threshold raises ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=-0.1) + + +# ---- Additional outputs ---- + + +def test_additional_outputs_empty(lj_model, ni_atoms): + """additional_outputs defaults to empty dict.""" + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + model(sim_state) + assert model.additional_outputs == {} + + +def test_additional_outputs_requested(lj_model, ni_atoms): + """Extra model outputs are stored in additional_outputs.""" + extra = { + "energy_ensemble": ModelOutput(quantity="energy", unit="eV", per_atom=True), + } + model = MetatomicModel( + model=lj_model, + device=DEVICE, + additional_outputs=extra, + uncertainty_threshold=None, + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + model(sim_state) + + assert "energy_ensemble" in 
model.additional_outputs + # energy_ensemble has 16 properties (ensemble members) + block = model.additional_outputs["energy_ensemble"].block() + assert block.values.shape[0] == len(ni_atoms) + + +# ---- Non-conservative ---- + + +def test_non_conservative_forces(lj_model, ni_atoms): + """NC forces are returned without autograd.""" + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "forces" in output + assert output["forces"].shape == (len(ni_atoms), 3) + # NC forces should have zero net force (mean-subtracted) + net_force = output["forces"].sum(dim=0) + torch.testing.assert_close( + net_force, torch.zeros(3, dtype=DTYPE), atol=1e-6, rtol=0 + ) + + +def test_non_conservative_stress(lj_model, ni_atoms): + """NC stress is returned with correct shape.""" + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "stress" in output + assert output["stress"].shape == (1, 3, 3) + + +def test_non_conservative_batched_forces(lj_model, ni_atoms): + """NC net-force subtraction is per-system in batched mode.""" + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) + ni_atoms_2 = ni_atoms.copy() + ni_atoms_2.positions += 0.3 * np.random.rand(*ni_atoms_2.positions.shape) + + sim_state = ts.io.atoms_to_state([ni_atoms, ni_atoms_2], DEVICE, DTYPE) + output = model(sim_state) + + n1 = len(ni_atoms) + n2 = len(ni_atoms_2) + forces = output["forces"] + assert forces.shape == (n1 + n2, 3) + + # Each system's forces should independently sum to zero + net_1 = forces[:n1].sum(dim=0) + net_2 = forces[n1:].sum(dim=0) + torch.testing.assert_close(net_1, torch.zeros(3, dtype=DTYPE), atol=1e-6, 
rtol=0) + torch.testing.assert_close(net_2, torch.zeros(3, dtype=DTYPE), atol=1e-6, rtol=0) + + +def test_non_conservative_missing_output_raises(lj_model_ext): + """ValueError when model lacks NC outputs.""" + with pytest.raises((ValueError, RuntimeError), match="not found"): + MetatomicModel(model=lj_model_ext, device=DEVICE, non_conservative=True) + + +def test_non_conservative_stress_only(lj_model, ni_atoms): + """NC mode with compute_forces=False returns only stress.""" + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + compute_forces=False, + uncertainty_threshold=None, + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "energy" in output + assert "forces" not in output + assert "stress" in output + assert output["stress"].shape == (1, 3, 3) + + +def test_non_conservative_with_variants(lj_model, ni_atoms): + """NC doubled variant gives different forces than base variant.""" + model_base = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) + model_doubled = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + variants={ + "energy": "doubled", + "non_conservative_forces": "doubled", + "non_conservative_stress": "doubled", + }, + ) + + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + out_base = model_base(sim_state) + out_doubled = model_doubled(sim_state) + + assert "energy" in out_doubled + assert "forces" in out_doubled + assert "stress" in out_doubled + + # Doubled energy should be 2x base + torch.testing.assert_close( + out_doubled["energy"], 2.0 * out_base["energy"], atol=1e-10, rtol=0 + ) + + +def test_additional_outputs_invalid_raises(lj_model): + """Passing non-ModelOutput values raises AssertionError.""" + with pytest.raises(AssertionError): + MetatomicModel( + model=lj_model, + device=DEVICE, + additional_outputs={"bad": "not a 
ModelOutput"}, + )