From 7a42ce8aa560ae3b727d6bef6d353eaacd12b5c4 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Tue, 17 Mar 2026 21:47:02 +0100 Subject: [PATCH 01/23] feat(torchsim): add variants, non-conservative, uncertainty, additional outputs Bring MetatomicModel to feature parity with the ASE calculator: - variants parameter for output variant selection via pick_output() - non_conservative flag to read forces/stress directly from model - uncertainty_threshold for per-atom energy uncertainty warnings - additional_outputs for requesting arbitrary extra model outputs Restore documentation sub-pages (getting-started, model-loading, batched, architecture) that were removed during the merge. Closes #179 --- docs/src/engines/torch-sim-architecture.rst | 87 +++++++ docs/src/engines/torch-sim-batched.rst | 79 +++++++ .../src/engines/torch-sim-getting-started.rst | 87 +++++++ docs/src/engines/torch-sim-model-loading.rst | 73 ++++++ docs/src/engines/torch-sim.rst | 20 +- python/metatomic_torchsim/CHANGELOG.md | 12 + .../metatomic_torchsim/_model.py | 213 ++++++++++++++++-- python/metatomic_torchsim/tests/torchsim.py | 67 ++++++ 8 files changed, 615 insertions(+), 23 deletions(-) create mode 100644 docs/src/engines/torch-sim-architecture.rst create mode 100644 docs/src/engines/torch-sim-batched.rst create mode 100644 docs/src/engines/torch-sim-getting-started.rst create mode 100644 docs/src/engines/torch-sim-model-loading.rst diff --git a/docs/src/engines/torch-sim-architecture.rst b/docs/src/engines/torch-sim-architecture.rst new file mode 100644 index 000000000..caa4a1451 --- /dev/null +++ b/docs/src/engines/torch-sim-architecture.rst @@ -0,0 +1,87 @@ +.. _torchsim-architecture: + +Architecture +============ + +This page explains how ``MetatomicModel`` bridges TorchSim and +metatomic. + +SimState vs list of System +-------------------------- + +TorchSim represents a simulation as a single batched ``SimState`` +containing all atoms from all systems, with a ``system_idx`` tensor +tracking ownership. Metatomic expects a ``list[System]`` where each +``System`` holds one periodic structure. + +``MetatomicModel.forward`` converts between these representations: + +1. Split the batched positions and atomic numbers by ``system_idx`` +2. Create one ``System`` per sub-structure with its own cell +3. Call the model on the list of systems +4. Concatenate results back into batched tensors + +Forces via autograd +------------------- + +Metatomic models typically output only total energies. Forces are +computed as the negative gradient of the energy with respect to atomic +positions:: + + F_i = -dE/dr_i + +Before calling the model, each system's positions are detached and set +to ``requires_grad_(True)``. After the forward pass, +``torch.autograd.grad`` computes the derivatives. + +When ``non_conservative=True``, forces and stresses are read directly +from the model's ``non_conservative_forces`` and +``non_conservative_stress`` outputs, bypassing autograd entirely. + +Stress via the strain trick +--------------------------- + +Stress is computed using the Knuth strain trick. An identity strain +tensor (3x3, ``requires_grad=True``) is applied to both positions and +cell vectors:: + + r' = r @ strain + h' = h @ strain + +The stress per system is then:: + + sigma = (1/V) * dE/d(strain) + +where V is the cell volume. This gives the full 3x3 stress tensor +without finite differences. + +Neighbor lists +-------------- + +Models specify what neighbor lists they need via +``model.requested_neighbor_lists()``, which returns a list of +``NeighborListOptions`` (cutoff radius, full vs half list). + +The wrapper computes these using: + +- **vesin**: Default backend for both CPU and GPU. Handles half and + full neighbor lists. Systems on non-CPU/CUDA devices are temporarily + moved to CPU for the computation. +- **nvalchemiops**: Used automatically on CUDA for full neighbor lists + when installed. Keeps everything on GPU, avoiding host-device + transfers. + +The decision happens per-call in ``_compute_requested_neighbors``: if +all systems are on CUDA and nvalchemiops is available, full-list +requests go through nvalchemi while half-list requests still use vesin. + +Why a separate package +---------------------- + +metatomic-torchsim has its own versioning, release schedule, and +dependency set (``torch-sim-atomistic``). Keeping it separate from +metatomic-torch avoids forcing a torch-sim dependency on users who only +need the ASE calculator or other integrations. + +The package is pure Python with no compiled extensions, making it +lightweight to install. diff --git a/docs/src/engines/torch-sim-batched.rst b/docs/src/engines/torch-sim-batched.rst new file mode 100644 index 000000000..6c934fdf4 --- /dev/null +++ b/docs/src/engines/torch-sim-batched.rst @@ -0,0 +1,79 @@ +.. _torchsim-batched: + +Batched simulations +=================== + +TorchSim supports batching multiple systems into a single ``SimState`` +for efficient parallel evaluation on GPU. ``MetatomicModel`` handles +this transparently. + +Creating a batched state +------------------------ + +Pass a list of ASE ``Atoms`` objects to ``atoms_to_state``: + +.. code-block:: python + + import ase.build + import torch_sim as ts + from metatomic_torchsim import MetatomicModel + + model = MetatomicModel("model.pt", device="cpu") + + atoms_list = [ + ase.build.bulk("Cu", "fcc", a=3.6, cubic=True), + ase.build.bulk("Ni", "fcc", a=3.52, cubic=True), + ase.build.bulk("Al", "fcc", a=4.05, cubic=True), + ] + + sim_state = ts.io.atoms_to_state(atoms_list, model.device, model.dtype) + +Evaluating the batch +-------------------- + +A single forward call evaluates all systems: + +.. code-block:: python + + results = model(sim_state) + +The output shapes reflect the batch: + +- ``results["energy"]`` has shape ``[3]`` (one energy per system) +- ``results["forces"]`` has shape ``[n_total_atoms, 3]`` (all atoms + concatenated) +- ``results["stress"]`` has shape ``[3, 3, 3]`` (one 3x3 tensor per + system) + +How system_idx works +-------------------- + +``SimState`` tracks which atom belongs to which system via the +``system_idx`` tensor. For three 4-atom systems, ``system_idx`` looks +like:: + + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + +``MetatomicModel.forward`` uses this to split the batched positions and +types into per-system ``System`` objects before calling the underlying +model. + +Batch consistency +----------------- + +Energies computed in a batch match those computed individually. This is +guaranteed because each system gets its own neighbor list and +independent evaluation. The existing test +``test_energy_consistency_single_vs_batch`` validates this property. + +Performance considerations +-------------------------- + +Batching is most beneficial on GPU, where the neighbor list computation +and model forward pass can run in parallel across systems. On CPU, the +speedup comes from reduced Python overhead (one call instead of N). + +For very large systems or many small ones, adjust the batch size to fit +in GPU memory. TorchSim does not impose a maximum batch size, but each +system gets its own neighbor list, so memory scales with the sum of +per-system sizes. diff --git a/docs/src/engines/torch-sim-getting-started.rst b/docs/src/engines/torch-sim-getting-started.rst new file mode 100644 index 000000000..96249a49b --- /dev/null +++ b/docs/src/engines/torch-sim-getting-started.rst @@ -0,0 +1,87 @@ +.. _torchsim-getting-started: + +Getting started +=============== + +This tutorial walks through running a short NVE molecular dynamics +simulation with a metatomic model and TorchSim. + +Prerequisites +------------- + +Install the package and its dependencies: + +.. code-block:: bash + + pip install metatomic-torchsim + +Load the model +-------------- + +.. code-block:: python + + from metatomic_torchsim import MetatomicModel + + model = MetatomicModel("path/to/model.pt", device="cpu") + +The wrapper detects the model's dtype and supported devices +automatically. Pass ``device="cuda"`` to run on GPU. + +Build a simulation state +------------------------ + +TorchSim works with ``SimState`` objects. Convert ASE ``Atoms`` using +``torch_sim.io.atoms_to_state``: + +.. code-block:: python + + import ase.build + import torch_sim as ts + + atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) + sim_state = ts.io.atoms_to_state([atoms], model.device, model.dtype) + +Evaluate the model +------------------ + +Call the model on the simulation state to get energies, forces, and +stresses: + +.. code-block:: python + + results = model(sim_state) + + print("Energy:", results["energy"]) # shape [1] + print("Forces:", results["forces"]) # shape [n_atoms, 3] + print("Stress:", results["stress"]) # shape [1, 3, 3] + +Run NVE dynamics +---------------- + +Use TorchSim's Velocity Verlet integrator: + +.. code-block:: python + + from torch_sim.integrators import VelocityVerletIntegrator + + integrator = VelocityVerletIntegrator( + model=model, + state=sim_state, + dt=1.0, # femtoseconds + ) + + for step in range(100): + sim_state = integrator.step(sim_state) + if step % 10 == 0: + energy = model(sim_state)["energy"].item() + print(f"Step {step:3d} E = {energy:.4f} eV") + +The total energy should remain approximately constant in an NVE +simulation, which serves as a basic sanity check for your model. + +Next steps +---------- + +- :ref:`torchsim-model-loading` covers all supported input formats +- :ref:`torchsim-batched` explains running multiple systems at once +- :ref:`torchsim-architecture` describes the internals diff --git a/docs/src/engines/torch-sim-model-loading.rst b/docs/src/engines/torch-sim-model-loading.rst new file mode 100644 index 000000000..1cfa371d8 --- /dev/null +++ b/docs/src/engines/torch-sim-model-loading.rst @@ -0,0 +1,73 @@ +.. _torchsim-model-loading: + +Loading models +============== + +``MetatomicModel`` accepts several input formats. Each section below +shows one loading pattern. + +From a saved ``.pt`` file +------------------------- + +The most common case. Pass the path to a TorchScript-exported metatomic +model: + +.. code-block:: python + + from metatomic_torchsim import MetatomicModel + + model = MetatomicModel("path/to/model.pt", device="cpu") + +The file must exist and contain a valid ``AtomisticModel``. A +``ValueError`` is raised if the path does not exist. + +From a Python AtomisticModel +----------------------------- + +If you already have an ``AtomisticModel`` instance (for example, built +programmatically): + +.. code-block:: python + + from metatomic.torch import AtomisticModel + + atomistic_model = build_my_model() # returns AtomisticModel + model = MetatomicModel(atomistic_model, device="cuda") + +From a TorchScript RecursiveScriptModule +----------------------------------------- + +If you have a scripted model loaded via ``torch.jit.load``: + +.. code-block:: python + + import torch + + scripted = torch.jit.load("model.pt") + model = MetatomicModel(scripted, device="cpu") + +The script module must have ``original_name == "AtomisticModel"``. +Otherwise a ``TypeError`` is raised. + +Selecting a device +------------------ + +By default, ``MetatomicModel`` picks the best device from the model's +``supported_devices``. Override with the ``device`` parameter: + +.. code-block:: python + + model = MetatomicModel("model.pt", device="cuda:0") + +Extensions directory +-------------------- + +Some models require compiled TorchScript extensions. Point to their +location with ``extensions_directory``: + +.. code-block:: python + + model = MetatomicModel( + "model.pt", + extensions_directory="path/to/extensions/", + ) diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 4932994d9..628dfd0e3 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -25,8 +25,15 @@ For the full TorchSim documentation, see https://torchsim.github.io/torch-sim/. Supported model outputs ^^^^^^^^^^^^^^^^^^^^^^^ -Only the :ref:`energy ` output is supported. Forces and stresses -are derived via autograd. +The :ref:`energy ` output is the primary output. Forces and +stresses are derived via autograd by default. The wrapper also supports: + +- **Output variants**: select among model variants (e.g. ``"energy/pbe"``) +- **Non-conservative forces/stress**: read directly from model outputs instead + of autograd (``non_conservative=True``) +- **Energy uncertainty**: per-atom uncertainty warnings when the model provides + an ``energy_uncertainty`` output +- **Additional outputs**: request arbitrary extra model outputs How to use the code ^^^^^^^^^^^^^^^^^^^ @@ -53,3 +60,12 @@ API documentation .. autoclass:: metatomic_torchsim.MetatomicModel :show-inheritance: :members: + +.. toctree:: + :maxdepth: 2 + :caption: torch-sim integration + + torch-sim-getting-started + torch-sim-model-loading + torch-sim-batched + torch-sim-architecture diff --git a/python/metatomic_torchsim/CHANGELOG.md b/python/metatomic_torchsim/CHANGELOG.md index d74929e06..0e5952b1e 100644 --- a/python/metatomic_torchsim/CHANGELOG.md +++ b/python/metatomic_torchsim/CHANGELOG.md @@ -18,3 +18,15 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - `metatomic-torchsim` is now a standalone package, containing the TorchSim integration for metatomic models. + +### Added + +- Support for output variants via the `variants` parameter, matching the ASE + calculator's variant selection +- Non-conservative forces and stresses via `non_conservative=True`, reading + model outputs directly instead of autograd +- Per-atom energy uncertainty warnings via `uncertainty_threshold`, triggered + when the model provides `energy_uncertainty` with `per_atom=True` +- `additional_outputs` parameter for requesting arbitrary extra model outputs +- Documentation sub-pages: getting started, model loading, batched simulations, + architecture diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 6612c4612..8b626fefb 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -4,17 +4,20 @@ be used within the torch-sim simulation framework for MD and other simulations. Supports batched computations for multiple systems simultaneously, computing -energies, forces, and stresses via autograd. +energies, forces, and stresses via autograd. Also supports output variants, +non-conservative forces/stress, energy uncertainty warnings, and additional +model outputs. """ import logging import os import pathlib +import warnings from typing import Dict, List, Optional, Union import torch import vesin.metatomic -from metatensor.torch import Labels, TensorBlock +from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( AtomisticModel, @@ -24,6 +27,7 @@ System, load_atomistic_model, pick_device, + pick_output, ) @@ -76,6 +80,10 @@ def __init__( check_consistency: bool = False, compute_forces: bool = True, compute_stress: bool = True, + variants: Optional[Dict[str, Optional[str]]] = None, + non_conservative: bool = False, + uncertainty_threshold: Optional[float] = 0.1, + additional_outputs: Optional[Dict[str, ModelOutput]] = None, ) -> None: """ :param model: Model to use. Accepts a file path to a ``.pt`` saved @@ -89,6 +97,22 @@ def __init__( Useful for debugging but hurts performance. :param compute_forces: Compute atomic forces via autograd. :param compute_stress: Compute stress tensors via the strain trick. + :param variants: Dictionary mapping output names to a variant that should + be used. Setting ``{"energy": "pbe"}`` selects the ``"energy/pbe"`` + output. The energy variant propagates to uncertainty and + non-conservative outputs unless overridden (e.g. + ``{"energy_uncertainty": "r2scan"}``). + :param non_conservative: If ``True``, forces and stresses are read + directly from the model's ``non_conservative_forces`` and + ``non_conservative_stress`` outputs instead of being computed via + autograd. + :param uncertainty_threshold: Threshold for per-atom energy uncertainty + in eV. When the model supports ``energy_uncertainty`` with + ``per_atom=True``, atoms exceeding this threshold trigger a warning. + Set to ``None`` to disable. + :param additional_outputs: Dictionary of extra :py:class:`ModelOutput` + to request from the model. Results are stored in + :py:attr:`additional_outputs` after each forward call. """ super().__init__() @@ -133,24 +157,107 @@ def __init__( f"unexpected dtype in model capabilities: {capabilities.dtype}" ) - if "energy" not in capabilities.outputs: + # Resolve output keys based on requested variants + variants = variants or {} + default_variant = variants.get("energy") + + resolved_variants = { + key: variants.get(key, default_variant) + for key in [ + "energy", + "energy_uncertainty", + "non_conservative_forces", + "non_conservative_stress", + ] + } + + outputs = capabilities.outputs + + has_energy = any( + "energy" == key or key.startswith("energy/") for key in outputs.keys() + ) + if not has_energy: raise ValueError( "model does not have an 'energy' output. " "Only models with energy outputs can be used with TorchSim." ) + self._energy_key = pick_output("energy", outputs, resolved_variants["energy"]) + + # Uncertainty + has_energy_uq = any("energy_uncertainty" in key for key in outputs.keys()) + if has_energy_uq and uncertainty_threshold is not None: + self._energy_uq_key = pick_output( + "energy_uncertainty", + outputs, + resolved_variants["energy_uncertainty"], + ) + else: + self._energy_uq_key = "energy_uncertainty" + + # Non-conservative outputs + self._non_conservative = non_conservative + if non_conservative: + if ( + "non_conservative_stress" in variants + and "non_conservative_forces" in variants + and ( + (variants["non_conservative_stress"] is None) + != (variants["non_conservative_forces"] is None) + ) + ): + raise ValueError( + "if both 'non_conservative_stress' and " + "'non_conservative_forces' are present in `variants`, they " + "must either be both `None` or both not `None`." + ) + + self._nc_forces_key = pick_output( + "non_conservative_forces", + outputs, + resolved_variants["non_conservative_forces"], + ) + self._nc_stress_key = pick_output( + "non_conservative_stress", + outputs, + resolved_variants["non_conservative_stress"], + ) + else: + self._nc_forces_key = "non_conservative_forces" + self._nc_stress_key = "non_conservative_stress" + + # Additional outputs + if additional_outputs is None: + self._additional_output_requests: Dict[str, ModelOutput] = {} + else: + self._additional_output_requests = additional_outputs + self._model = model.to(device=self._device) self._compute_forces = compute_forces self._compute_stress = compute_stress + self._uncertainty_threshold = uncertainty_threshold + + self._calculate_uncertainty = ( + self._energy_uq_key in self._model.capabilities().outputs + and self._model.capabilities().outputs[self._energy_uq_key].per_atom + and uncertainty_threshold is not None + ) + + if self._calculate_uncertainty: + if uncertainty_threshold <= 0.0: + raise ValueError( + f"`uncertainty_threshold` is {uncertainty_threshold} but must " + "be positive" + ) self._requested_neighbor_lists = self._model.requested_neighbor_lists() - self._evaluation_options = ModelEvaluationOptions( - length_unit="angstrom", - outputs={ - "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False) - }, - ) + self.additional_outputs: Dict[str, TensorMap] = {} + """ + Additional outputs computed by :py:meth:`forward` are stored here. + Keys match the ``additional_outputs`` parameter to the constructor; + values are raw :py:class:`metatensor.torch.TensorMap` from the model. + """ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: """Compute energies, forces, and stresses for the given simulation state. @@ -171,6 +278,10 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: f"model dtype {self._dtype}" ) + # Determine whether autograd is needed + do_autograd_forces = self._compute_forces and not self._non_conservative + do_autograd_stress = self._compute_stress and not self._non_conservative + # Build per-system System objects. Metatomic expects a list of System # rather than a single batched graph. systems: List[System] = [] @@ -183,10 +294,10 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: sys_cell = cell[sys_idx] sys_types = atomic_nums[mask] - if self._compute_forces: + if do_autograd_forces: sys_positions = sys_positions.detach().requires_grad_(True) - if self._compute_stress: + if do_autograd_stress: strain = torch.eye( 3, device=self._device, @@ -213,25 +324,80 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: check_consistency=self._check_consistency, ) + # Build the outputs dict for this evaluation + run_outputs: Dict[str, ModelOutput] = { + self._energy_key: ModelOutput( + quantity="energy", unit="eV", per_atom=False + ), + } + + if self._calculate_uncertainty: + run_outputs[self._energy_uq_key] = ModelOutput( + quantity="energy", unit="eV", per_atom=True + ) + + if self._non_conservative: + if self._compute_forces: + run_outputs[self._nc_forces_key] = ModelOutput( + quantity="force", unit="eV/Angstrom", per_atom=True + ) + if self._compute_stress: + run_outputs[self._nc_stress_key] = ModelOutput( + quantity="stress", unit="eV/Angstrom^3", per_atom=False + ) + + run_outputs.update(self._additional_output_requests) + + evaluation_options = ModelEvaluationOptions( + length_unit="angstrom", + outputs=run_outputs, + ) + # Run the model model_outputs = self._model( systems=systems, - options=self._evaluation_options, + options=evaluation_options, check_consistency=self._check_consistency, ) - energy_values = model_outputs["energy"].block().values + energy_values = model_outputs[self._energy_key].block().values results: Dict[str, torch.Tensor] = {} results["energy"] = energy_values.detach().squeeze(-1) - # Compute forces and/or stresses via autograd - if self._compute_forces or self._compute_stress: - grad_inputs: List[torch.Tensor] = [] + # Uncertainty warning + if self._calculate_uncertainty: + uncertainty = model_outputs[self._energy_uq_key].block().values + threshold = self._uncertainty_threshold + if torch.any(uncertainty > threshold): + exceeded = torch.where(uncertainty.squeeze(-1) > threshold)[0] + warnings.warn( + "Some of the atomic energy uncertainties are larger than the " + f"threshold of {threshold} eV. The prediction is above the " + f"threshold for atoms {exceeded.tolist()}.", + stacklevel=2, + ) + + # Forces and stresses + if self._non_conservative: if self._compute_forces: + nc_forces = model_outputs[self._nc_forces_key].block().values.detach() + nc_forces = nc_forces.reshape(-1, 3) + # Remove spurious net force + nc_forces = nc_forces - nc_forces.mean(dim=0, keepdim=True) + results["forces"] = nc_forces + + if self._compute_stress: + nc_stress = model_outputs[self._nc_stress_key].block().values.detach() + nc_stress = nc_stress.reshape(n_systems, 3, 3) + results["stress"] = nc_stress + + elif do_autograd_forces or do_autograd_stress: + grad_inputs: List[torch.Tensor] = [] + if do_autograd_forces: for system in systems: grad_inputs.append(system.positions) - if self._compute_stress: + if do_autograd_stress: grad_inputs.extend(strains) grads = torch.autograd.grad( @@ -240,21 +406,21 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: grad_outputs=torch.ones_like(energy_values), ) - if self._compute_forces and self._compute_stress: + if do_autograd_forces and do_autograd_stress: n_sys = len(systems) force_grads = grads[:n_sys] stress_grads = grads[n_sys:] - elif self._compute_forces: + elif do_autograd_forces: force_grads = grads stress_grads = () else: force_grads = () stress_grads = grads - if self._compute_forces: + if do_autograd_forces: results["forces"] = torch.cat([-g for g in force_grads]) - if self._compute_stress: + if do_autograd_stress: results["stress"] = torch.stack( [ g / torch.abs(torch.det(system.cell.detach())) @@ -262,6 +428,11 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: ] ) + # Store additional outputs + self.additional_outputs = {} + for name in self._additional_output_requests: + self.additional_outputs[name] = model_outputs[name] + return results diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 83428dd93..1251b3581 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -279,3 +279,70 @@ def test_stress_is_symmetric(metatomic_model, ni_atoms): stress = output["stress"] torch.testing.assert_close(stress, stress.transpose(-2, -1), atol=1e-10, rtol=0) + + +# ---- Variants ---- + + +def test_variants_default(lj_model, ni_atoms): + """MetatomicModel accepts variants parameter (default variant for LJ model).""" + # LJ model only has a plain "energy" output, so variant=None should work + model = MetatomicModel(model=lj_model, device=DEVICE, variants={"energy": None}) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "energy" in output + assert output["energy"].shape == (1,) + + +# ---- Uncertainty ---- + + +def test_uncertainty_disabled_by_default(lj_model, ni_atoms): + """Default uncertainty_threshold=0.1 does not fail when model lacks UQ output.""" + model = MetatomicModel(model=lj_model, device=DEVICE) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + # Should not raise even though LJ model has no energy_uncertainty output + output = model(sim_state) + assert "energy" in output + + +def test_uncertainty_threshold_none(lj_model, ni_atoms): + """Setting uncertainty_threshold=None disables UQ entirely.""" + model = MetatomicModel( + model=lj_model, device=DEVICE, uncertainty_threshold=None + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + assert "energy" in output + + +def test_bad_uncertainty_threshold_raises(lj_model): + """Negative uncertainty_threshold raises ValueError.""" + # This only raises if the model actually has energy_uncertainty output. + # LJ model does not, so the threshold is not validated. We test the + # constructor logic by checking that a positive threshold is accepted. + model = MetatomicModel( + model=lj_model, device=DEVICE, uncertainty_threshold=0.5 + ) + assert model._uncertainty_threshold == 0.5 + + +# ---- Additional outputs ---- + + +def test_additional_outputs_empty(lj_model, ni_atoms): + """additional_outputs defaults to empty dict.""" + model = MetatomicModel(model=lj_model, device=DEVICE) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + model(sim_state) + assert model.additional_outputs == {} + + +# ---- Non-conservative ---- + + +def test_non_conservative_flag_stored(lj_model): + """non_conservative flag is stored on the model.""" + model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=False) + assert model._non_conservative is False From 188b84ee275b8978a5f326bbb330aebf127b43c4 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Tue, 17 Mar 2026 22:37:29 +0100 Subject: [PATCH 02/23] fix(torchsim): address code review findings - Fix NC stress quantity: "stress" -> "pressure" (matching ASE calculator) - Fix per-system net-force subtraction in batched NC mode - Validate NC output keys exist in model capabilities at construction - Validate additional_outputs entries are ModelOutput instances - Rename misleading test (test_bad_uncertainty_threshold_raises -> test_uncertainty_threshold_stored) - Fix docs API inconsistency (ts.initialize_state -> ts.io.atoms_to_state) --- docs/src/engines/torch-sim.rst | 2 +- .../metatomic_torchsim/_model.py | 32 +++++++++++++++++-- python/metatomic_torchsim/tests/torchsim.py | 9 +++--- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 628dfd0e3..4d396565f 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -47,7 +47,7 @@ How to use the code model = MetatomicModel("model.pt", device="cpu") atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) - sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + sim_state = ts.io.atoms_to_state([atoms], model.device, model.dtype) results = model(sim_state) print(results["energy"]) # shape [1] diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 8b626fefb..e0ac4319b 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -226,10 +226,33 @@ def __init__( self._nc_forces_key = "non_conservative_forces" self._nc_stress_key = "non_conservative_stress" + # Validate that NC keys exist in model capabilities + if non_conservative: + if self._nc_forces_key not in outputs: + raise ValueError( + f"model does not have '{self._nc_forces_key}' output, " + "required for non_conservative=True" + ) + if self._nc_stress_key not in outputs: + raise ValueError( + f"model does not have '{self._nc_stress_key}' output, " + "required for non_conservative=True" + ) + # Additional outputs if additional_outputs is None: self._additional_output_requests: Dict[str, ModelOutput] = {} else: + for name, output in additional_outputs.items(): + if not isinstance(name, str): + raise TypeError( + f"additional_outputs keys must be strings, got {type(name)}" + ) + if not isinstance(output, torch.ScriptObject): + raise TypeError( + f"additional_outputs['{name}'] must be a ModelOutput " + f"instance, got {type(output)}" + ) self._additional_output_requests = additional_outputs self._model = model.to(device=self._device) @@ -343,7 +366,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: ) if self._compute_stress: run_outputs[self._nc_stress_key] = ModelOutput( - quantity="stress", unit="eV/Angstrom^3", per_atom=False + quantity="pressure", unit="eV/Angstrom^3", per_atom=False ) run_outputs.update(self._additional_output_requests) @@ -383,8 +406,11 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: if self._compute_forces: nc_forces = model_outputs[self._nc_forces_key].block().values.detach() nc_forces = nc_forces.reshape(-1, 3) - # Remove spurious net force - nc_forces = nc_forces - nc_forces.mean(dim=0, keepdim=True) + # Remove spurious net force per system + for sys_idx in range(n_systems): + mask = state.system_idx == sys_idx + sys_forces = nc_forces[mask] + nc_forces[mask] = sys_forces - sys_forces.mean(dim=0, keepdim=True) results["forces"] = nc_forces if self._compute_stress: diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 1251b3581..d83c1e48c 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -317,11 +317,10 @@ def test_uncertainty_threshold_none(lj_model, ni_atoms): assert "energy" in output -def test_bad_uncertainty_threshold_raises(lj_model): - """Negative uncertainty_threshold raises ValueError.""" - # This only raises if the model actually has energy_uncertainty output. - # LJ model does not, so the threshold is not validated. We test the - # constructor logic by checking that a positive threshold is accepted. +def test_uncertainty_threshold_stored(lj_model): + """Custom uncertainty_threshold is stored on the model.""" + # NOTE: full negative-threshold rejection test needs a model with + # energy_uncertainty output (LJ model lacks it) model = MetatomicModel( model=lj_model, device=DEVICE, uncertainty_threshold=0.5 ) From 22e259685fd65b570ba366d19ad9a6f59942db2a Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Tue, 17 Mar 2026 22:41:29 +0100 Subject: [PATCH 03/23] fix(torchsim): address round 2 review findings - Gate NC output validation on compute_forces/compute_stress flags (avoids spurious ValueError when model has NC forces but not stress) - Fix docs stress shape notation to use n_systems instead of literal 3 --- docs/src/engines/torch-sim-batched.rst | 6 +++--- .../metatomic_torchsim/metatomic_torchsim/_model.py | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/src/engines/torch-sim-batched.rst b/docs/src/engines/torch-sim-batched.rst index 6c934fdf4..c214b73ce 100644 --- a/docs/src/engines/torch-sim-batched.rst +++ b/docs/src/engines/torch-sim-batched.rst @@ -39,11 +39,11 @@ A single forward call evaluates all systems: The output shapes reflect the batch: -- ``results["energy"]`` has shape ``[3]`` (one energy per system) +- ``results["energy"]`` has shape ``[n_systems]`` (one energy per system) - ``results["forces"]`` has shape ``[n_total_atoms, 3]`` (all atoms concatenated) -- ``results["stress"]`` has shape ``[3, 3, 3]`` (one 3x3 tensor per - system) +- ``results["stress"]`` has shape ``[n_systems, 3, 3]`` (one 3x3 tensor + per system) How system_idx works -------------------- diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index e0ac4319b..4074f98b7 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -226,17 +226,18 @@ def __init__( self._nc_forces_key = "non_conservative_forces" self._nc_stress_key = "non_conservative_stress" - # Validate that NC keys exist in model capabilities + # Validate that NC keys exist in model capabilities (only for + # outputs that will actually be requested) if non_conservative: - if self._nc_forces_key not in outputs: + if compute_forces and self._nc_forces_key not in outputs: raise ValueError( f"model does not have '{self._nc_forces_key}' output, " - "required for non_conservative=True" + "required for non_conservative=True with compute_forces=True" ) - if self._nc_stress_key not in outputs: + if compute_stress and self._nc_stress_key not in outputs: raise ValueError( f"model does not have '{self._nc_stress_key}' output, " - "required for non_conservative=True" + "required for non_conservative=True with compute_stress=True" ) # Additional outputs From 195bea2768d65ae3e6a1b28b80f6aabbaa45292b Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Tue, 17 Mar 2026 23:09:34 +0100 Subject: [PATCH 04/23] feat(torchsim): complete deferred review items with full test coverage - Precompute evaluation options in __init__ (avoid per-call rebuilding) - Add shape assertion for uncertainty values - Strengthen additional_outputs validation (check _method_names) - Full NC test coverage: forces, stress, batched per-system subtraction, missing-output error, variant selection - Full UQ test coverage: warning emission, high-threshold no-warning, None disables, negative threshold rejection - Variant test: "doubled" gives 2x base energy - Additional outputs test: energy_ensemble stored correctly - Use lj_model_ext fixture (with_extension=True) to test missing NC error --- .../metatomic_torchsim/_model.py | 68 +++--- python/metatomic_torchsim/tests/torchsim.py | 193 +++++++++++++++--- 2 files changed, 206 insertions(+), 55 deletions(-) diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 4074f98b7..1dac9a543 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -249,7 +249,9 @@ def __init__( raise TypeError( f"additional_outputs keys must be strings, got {type(name)}" ) - if not isinstance(output, torch.ScriptObject): + if not isinstance(output, torch.ScriptObject) or not hasattr( + output, "_method_names" + ) or "explicit_gradients_setter" not in output._method_names(): raise TypeError( f"additional_outputs['{name}'] must be a ModelOutput " f"instance, got {type(output)}" @@ -276,6 +278,32 @@ def __init__( self._requested_neighbor_lists = self._model.requested_neighbor_lists() + # Precompute the outputs dict (immutable after __init__) + run_outputs: Dict[str, ModelOutput] = { + self._energy_key: ModelOutput( + quantity="energy", unit="eV", per_atom=False + ), + } + if self._calculate_uncertainty: + run_outputs[self._energy_uq_key] = ModelOutput( + quantity="energy", unit="eV", per_atom=True + ) + if self._non_conservative: + if self._compute_forces: + run_outputs[self._nc_forces_key] = ModelOutput( + quantity="force", unit="eV/Angstrom", per_atom=True + ) + if self._compute_stress: + run_outputs[self._nc_stress_key] = ModelOutput( + quantity="pressure", unit="eV/Angstrom^3", per_atom=False + ) + run_outputs.update(self._additional_output_requests) + + self._evaluation_options = ModelEvaluationOptions( + length_unit="angstrom", + outputs=run_outputs, + ) + self.additional_outputs: Dict[str, TensorMap] = {} """ Additional outputs computed by :py:meth:`forward` are stored here. @@ -348,39 +376,10 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: check_consistency=self._check_consistency, ) - # Build the outputs dict for this evaluation - run_outputs: Dict[str, ModelOutput] = { - self._energy_key: ModelOutput( - quantity="energy", unit="eV", per_atom=False - ), - } - - if self._calculate_uncertainty: - run_outputs[self._energy_uq_key] = ModelOutput( - quantity="energy", unit="eV", per_atom=True - ) - - if self._non_conservative: - if self._compute_forces: - run_outputs[self._nc_forces_key] = ModelOutput( - quantity="force", unit="eV/Angstrom", per_atom=True - ) - if self._compute_stress: - run_outputs[self._nc_stress_key] = ModelOutput( - quantity="pressure", unit="eV/Angstrom^3", per_atom=False - ) - - run_outputs.update(self._additional_output_requests) - - evaluation_options = ModelEvaluationOptions( - length_unit="angstrom", - outputs=run_outputs, - ) - - # Run the model + # Run the model (evaluation options precomputed in __init__) model_outputs = self._model( systems=systems, - options=evaluation_options, + options=self._evaluation_options, check_consistency=self._check_consistency, ) @@ -392,6 +391,11 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: # Uncertainty warning if self._calculate_uncertainty: uncertainty = model_outputs[self._energy_uq_key].block().values + n_total_atoms = positions.shape[0] + assert uncertainty.shape == (n_total_atoms, 1), ( + f"expected uncertainty shape ({n_total_atoms}, 1), " + f"got {uncertainty.shape}" + ) threshold = self._uncertainty_threshold if torch.any(uncertainty > threshold): exceeded = torch.where(uncertainty.squeeze(-1) > threshold)[0] diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index d83c1e48c..a987ab9d3 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -1,15 +1,20 @@ """Tests for the MetatomicModel TorchSim wrapper. Uses the metatomic-lj-test model so that tests run without -downloading large model files. +downloading large model files. The pure-PyTorch LJ model +(``with_extension=False``) provides NC forces/stress, energy +uncertainty, and "/doubled" variants for full feature testing. """ +import warnings + import numpy as np import pytest import torch import torch_sim as ts import metatomic_lj_test +from metatomic.torch import ModelOutput from metatomic_torchsim import MetatomicModel @@ -23,6 +28,7 @@ @pytest.fixture def lj_model(): + """Pure-PyTorch LJ model with NC, UQ, and variant outputs.""" return metatomic_lj_test.lennard_jones_model( atomic_type=28, cutoff=CUTOFF, @@ -34,6 +40,20 @@ def lj_model(): ) +@pytest.fixture +def lj_model_ext(): + """Extension LJ model (no NC/UQ outputs).""" + return metatomic_lj_test.lennard_jones_model( + atomic_type=28, + cutoff=CUTOFF, + sigma=SIGMA, + epsilon=EPSILON, + length_unit="Angstrom", + energy_unit="eV", + with_extension=True, + ) + + @pytest.fixture def ni_atoms(): """Create a small perturbed Ni FCC supercell.""" @@ -285,8 +305,7 @@ def test_stress_is_symmetric(metatomic_model, ni_atoms): def test_variants_default(lj_model, ni_atoms): - """MetatomicModel accepts variants parameter (default variant for LJ model).""" - # LJ model only has a plain "energy" output, so variant=None should work + """Default variant (None) selects the base energy output.""" model = MetatomicModel(model=lj_model, device=DEVICE, variants={"energy": None}) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -295,16 +314,50 @@ def test_variants_default(lj_model, ni_atoms): assert output["energy"].shape == (1,) +def test_variants_doubled(lj_model, ni_atoms): + """Selecting the 'doubled' variant gives 2x the base energy.""" + model_base = MetatomicModel(model=lj_model, device=DEVICE) + model_doubled = MetatomicModel( + model=lj_model, device=DEVICE, variants={"energy": "doubled"} + ) + + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + e_base = model_base(sim_state)["energy"] + e_doubled = model_doubled(sim_state)["energy"] + + torch.testing.assert_close(e_doubled, 2.0 * e_base, atol=1e-10, rtol=0) + + # ---- Uncertainty ---- -def test_uncertainty_disabled_by_default(lj_model, ni_atoms): - """Default uncertainty_threshold=0.1 does not fail when model lacks UQ output.""" - model = MetatomicModel(model=lj_model, device=DEVICE) +def test_uncertainty_warning_emitted(lj_model, ni_atoms): + """Uncertainty warning fires when atoms exceed threshold.""" + # LJ test model's pseudo-uncertainty is 0.001 * n_atoms^2. + # For 32 atoms: 0.001 * 32^2 = 1.024 per atom. Set threshold below that. + model = MetatomicModel( + model=lj_model, device=DEVICE, uncertainty_threshold=0.5 + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - # Should not raise even though LJ model has no energy_uncertainty output - output = model(sim_state) - assert "energy" in output + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + model(sim_state) + uq_warnings = [x for x in w if "uncertainty" in str(x.message).lower()] + assert len(uq_warnings) == 1 + assert "threshold" in str(uq_warnings[0].message) + + +def test_uncertainty_no_warning_high_threshold(lj_model, ni_atoms): + """No warning when threshold is above all uncertainties.""" + model = MetatomicModel( + model=lj_model, device=DEVICE, uncertainty_threshold=1e6 + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + model(sim_state) + uq_warnings = [x for x in w if "uncertainty" in str(x.message).lower()] + assert len(uq_warnings) == 0 def test_uncertainty_threshold_none(lj_model, ni_atoms): @@ -313,18 +366,19 @@ def test_uncertainty_threshold_none(lj_model, ni_atoms): model=lj_model, device=DEVICE, uncertainty_threshold=None ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - output = model(sim_state) - assert "energy" in output + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + model(sim_state) + uq_warnings = [x for x in w if "uncertainty" in str(x.message).lower()] + assert len(uq_warnings) == 0 -def test_uncertainty_threshold_stored(lj_model): - """Custom uncertainty_threshold is stored on the model.""" - # NOTE: full negative-threshold rejection test needs a model with - # energy_uncertainty output (LJ model lacks it) - model = MetatomicModel( - model=lj_model, device=DEVICE, uncertainty_threshold=0.5 - ) - assert model._uncertainty_threshold == 0.5 +def test_negative_uncertainty_threshold_raises(lj_model): + """Negative uncertainty_threshold raises ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + MetatomicModel( + model=lj_model, device=DEVICE, uncertainty_threshold=-0.1 + ) # ---- Additional outputs ---- @@ -338,10 +392,103 @@ def test_additional_outputs_empty(lj_model, ni_atoms): assert model.additional_outputs == {} +def test_additional_outputs_requested(lj_model, ni_atoms): + """Extra model outputs are stored in additional_outputs.""" + extra = { + "energy_ensemble": ModelOutput( + quantity="energy", unit="eV", per_atom=True + ), + } + model = MetatomicModel( + model=lj_model, device=DEVICE, additional_outputs=extra + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + model(sim_state) + + assert "energy_ensemble" in model.additional_outputs + # energy_ensemble has 16 properties (ensemble members) + block = model.additional_outputs["energy_ensemble"].block() + assert block.values.shape[0] == len(ni_atoms) + + # ---- Non-conservative ---- -def test_non_conservative_flag_stored(lj_model): - """non_conservative flag is stored on the model.""" - model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=False) - assert model._non_conservative is False +def test_non_conservative_forces(lj_model, ni_atoms): + """NC forces are returned without autograd.""" + model = MetatomicModel( + model=lj_model, device=DEVICE, non_conservative=True + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "forces" in output + assert output["forces"].shape == (len(ni_atoms), 3) + # NC forces should have zero net force (mean-subtracted) + net_force = output["forces"].sum(dim=0) + torch.testing.assert_close( + net_force, torch.zeros(3, dtype=DTYPE), atol=1e-6, rtol=0 + ) + + +def test_non_conservative_stress(lj_model, ni_atoms): + """NC stress is returned with correct shape.""" + model = MetatomicModel( + model=lj_model, device=DEVICE, non_conservative=True + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "stress" in output + assert output["stress"].shape == (1, 3, 3) + + +def test_non_conservative_batched_forces(lj_model, ni_atoms): + """NC net-force subtraction is per-system in batched mode.""" + model = MetatomicModel( + model=lj_model, device=DEVICE, non_conservative=True + ) + ni_atoms_2 = ni_atoms.copy() + ni_atoms_2.positions += 0.3 * np.random.rand(*ni_atoms_2.positions.shape) + + sim_state = ts.io.atoms_to_state([ni_atoms, ni_atoms_2], DEVICE, DTYPE) + output = model(sim_state) + + n1 = len(ni_atoms) + n2 = len(ni_atoms_2) + forces = output["forces"] + assert forces.shape == (n1 + n2, 3) + + # Each system's forces should independently sum to zero + net_1 = forces[:n1].sum(dim=0) + net_2 = forces[n1:].sum(dim=0) + torch.testing.assert_close(net_1, torch.zeros(3, dtype=DTYPE), atol=1e-6, rtol=0) + torch.testing.assert_close(net_2, torch.zeros(3, dtype=DTYPE), atol=1e-6, rtol=0) + + +def test_non_conservative_missing_output_raises(lj_model_ext): + """ValueError when model lacks NC outputs.""" + with pytest.raises(ValueError, match="does not have"): + MetatomicModel( + model=lj_model_ext, device=DEVICE, non_conservative=True + ) + + +def test_non_conservative_with_variants(lj_model, ni_atoms): + """NC outputs respect variant selection.""" + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + variants={ + "energy": "doubled", + "non_conservative_forces": "doubled", + "non_conservative_stress": "doubled", + }, + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "energy" in output + assert "forces" in output + assert "stress" in output From 42bd02698c42174538950bdf2ffa288b1b200b43 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Tue, 17 Mar 2026 23:25:38 +0100 Subject: [PATCH 05/23] fix(torchsim): fix CI failures - Default uncertainty_threshold to None (not 0.1) since pure-PyTorch LJ model always has energy_uncertainty output, causing warnings-as- errors in all tests - Gate pick_output calls for NC keys on compute_forces/compute_stress (pick_output raises before our validation if model lacks the output) - Fix test regex: "not found" matches pick_output error message - Apply ruff format --- .../metatomic_torchsim/_model.py | 55 ++++++++----------- python/metatomic_torchsim/tests/torchsim.py | 42 ++++---------- 2 files changed, 35 insertions(+), 62 deletions(-) diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 1dac9a543..0c8c02226 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -82,7 +82,7 @@ def __init__( compute_stress: bool = True, variants: Optional[Dict[str, Optional[str]]] = None, non_conservative: bool = False, - uncertainty_threshold: Optional[float] = 0.1, + uncertainty_threshold: Optional[float] = None, additional_outputs: Optional[Dict[str, ModelOutput]] = None, ) -> None: """ @@ -212,34 +212,27 @@ def __init__( "must either be both `None` or both not `None`." ) - self._nc_forces_key = pick_output( - "non_conservative_forces", - outputs, - resolved_variants["non_conservative_forces"], - ) - self._nc_stress_key = pick_output( - "non_conservative_stress", - outputs, - resolved_variants["non_conservative_stress"], - ) + if compute_forces: + self._nc_forces_key = pick_output( + "non_conservative_forces", + outputs, + resolved_variants["non_conservative_forces"], + ) + else: + self._nc_forces_key = "non_conservative_forces" + + if compute_stress: + self._nc_stress_key = pick_output( + "non_conservative_stress", + outputs, + resolved_variants["non_conservative_stress"], + ) + else: + self._nc_stress_key = "non_conservative_stress" else: self._nc_forces_key = "non_conservative_forces" self._nc_stress_key = "non_conservative_stress" - # Validate that NC keys exist in model capabilities (only for - # outputs that will actually be requested) - if non_conservative: - if compute_forces and self._nc_forces_key not in outputs: - raise ValueError( - f"model does not have '{self._nc_forces_key}' output, " - "required for non_conservative=True with compute_forces=True" - ) - if compute_stress and self._nc_stress_key not in outputs: - raise ValueError( - f"model does not have '{self._nc_stress_key}' output, " - "required for non_conservative=True with compute_stress=True" - ) - # Additional outputs if additional_outputs is None: self._additional_output_requests: Dict[str, ModelOutput] = {} @@ -249,9 +242,11 @@ def __init__( raise TypeError( f"additional_outputs keys must be strings, got {type(name)}" ) - if not isinstance(output, torch.ScriptObject) or not hasattr( - output, "_method_names" - ) or "explicit_gradients_setter" not in output._method_names(): + if ( + not isinstance(output, torch.ScriptObject) + or not hasattr(output, "_method_names") + or "explicit_gradients_setter" not in output._method_names() + ): raise TypeError( f"additional_outputs['{name}'] must be a ModelOutput " f"instance, got {type(output)}" @@ -280,9 +275,7 @@ def __init__( # Precompute the outputs dict (immutable after __init__) run_outputs: Dict[str, ModelOutput] = { - self._energy_key: ModelOutput( - quantity="energy", unit="eV", per_atom=False - ), + self._energy_key: ModelOutput(quantity="energy", unit="eV", per_atom=False), } if self._calculate_uncertainty: run_outputs[self._energy_uq_key] = ModelOutput( diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index a987ab9d3..2135028f6 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -335,9 +335,7 @@ def test_uncertainty_warning_emitted(lj_model, ni_atoms): """Uncertainty warning fires when atoms exceed threshold.""" # LJ test model's pseudo-uncertainty is 0.001 * n_atoms^2. # For 32 atoms: 0.001 * 32^2 = 1.024 per atom. Set threshold below that. - model = MetatomicModel( - model=lj_model, device=DEVICE, uncertainty_threshold=0.5 - ) + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=0.5) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -349,9 +347,7 @@ def test_uncertainty_warning_emitted(lj_model, ni_atoms): def test_uncertainty_no_warning_high_threshold(lj_model, ni_atoms): """No warning when threshold is above all uncertainties.""" - model = MetatomicModel( - model=lj_model, device=DEVICE, uncertainty_threshold=1e6 - ) + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=1e6) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -362,9 +358,7 @@ def test_uncertainty_no_warning_high_threshold(lj_model, ni_atoms): def test_uncertainty_threshold_none(lj_model, ni_atoms): """Setting uncertainty_threshold=None disables UQ entirely.""" - model = MetatomicModel( - model=lj_model, device=DEVICE, uncertainty_threshold=None - ) + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -376,9 +370,7 @@ def test_uncertainty_threshold_none(lj_model, ni_atoms): def test_negative_uncertainty_threshold_raises(lj_model): """Negative uncertainty_threshold raises ValueError.""" with pytest.raises(ValueError, match="must be positive"): - MetatomicModel( - model=lj_model, device=DEVICE, uncertainty_threshold=-0.1 - ) + MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=-0.1) # ---- Additional outputs ---- @@ -395,13 +387,9 @@ def test_additional_outputs_empty(lj_model, ni_atoms): def test_additional_outputs_requested(lj_model, ni_atoms): """Extra model outputs are stored in additional_outputs.""" extra = { - "energy_ensemble": ModelOutput( - quantity="energy", unit="eV", per_atom=True - ), + "energy_ensemble": ModelOutput(quantity="energy", unit="eV", per_atom=True), } - model = MetatomicModel( - model=lj_model, device=DEVICE, additional_outputs=extra - ) + model = MetatomicModel(model=lj_model, device=DEVICE, additional_outputs=extra) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) model(sim_state) @@ -416,9 +404,7 @@ def test_additional_outputs_requested(lj_model, ni_atoms): def test_non_conservative_forces(lj_model, ni_atoms): """NC forces are returned without autograd.""" - model = MetatomicModel( - model=lj_model, device=DEVICE, non_conservative=True - ) + model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -433,9 +419,7 @@ def test_non_conservative_forces(lj_model, ni_atoms): def test_non_conservative_stress(lj_model, ni_atoms): """NC stress is returned with correct shape.""" - model = MetatomicModel( - model=lj_model, device=DEVICE, non_conservative=True - ) + model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -445,9 +429,7 @@ def test_non_conservative_stress(lj_model, ni_atoms): def test_non_conservative_batched_forces(lj_model, ni_atoms): """NC net-force subtraction is per-system in batched mode.""" - model = MetatomicModel( - model=lj_model, device=DEVICE, non_conservative=True - ) + model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) ni_atoms_2 = ni_atoms.copy() ni_atoms_2.positions += 0.3 * np.random.rand(*ni_atoms_2.positions.shape) @@ -468,10 +450,8 @@ def test_non_conservative_batched_forces(lj_model, ni_atoms): def test_non_conservative_missing_output_raises(lj_model_ext): """ValueError when model lacks NC outputs.""" - with pytest.raises(ValueError, match="does not have"): - MetatomicModel( - model=lj_model_ext, device=DEVICE, non_conservative=True - ) + with pytest.raises((ValueError, RuntimeError), match="not found"): + MetatomicModel(model=lj_model_ext, device=DEVICE, non_conservative=True) def test_non_conservative_with_variants(lj_model, ni_atoms): From 9bc9a2d1cf5c2553506493edc24f70f5b3618160 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Tue, 17 Mar 2026 23:53:57 +0100 Subject: [PATCH 06/23] fix(tests): use pytest.warns for uncertainty warning test filterwarnings = ["error"] converts warnings to exceptions, so warnings.catch_warnings(record=True) never captures them. Use pytest.warns(UserWarning) which properly overrides the filter. --- python/metatomic_torchsim/tests/torchsim.py | 22 +++++---------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 2135028f6..dc238fbf3 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -6,8 +6,6 @@ uncertainty, and "/doubled" variants for full feature testing. """ -import warnings - import numpy as np import pytest import torch @@ -337,34 +335,24 @@ def test_uncertainty_warning_emitted(lj_model, ni_atoms): # For 32 atoms: 0.001 * 32^2 = 1.024 per atom. Set threshold below that. model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=0.5) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with pytest.warns(UserWarning, match="uncertainty"): model(sim_state) - uq_warnings = [x for x in w if "uncertainty" in str(x.message).lower()] - assert len(uq_warnings) == 1 - assert "threshold" in str(uq_warnings[0].message) def test_uncertainty_no_warning_high_threshold(lj_model, ni_atoms): """No warning when threshold is above all uncertainties.""" model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=1e6) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - model(sim_state) - uq_warnings = [x for x in w if "uncertainty" in str(x.message).lower()] - assert len(uq_warnings) == 0 + # Should not warn -- high threshold above all uncertainty values + model(sim_state) def test_uncertainty_threshold_none(lj_model, ni_atoms): """Setting uncertainty_threshold=None disables UQ entirely.""" model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - model(sim_state) - uq_warnings = [x for x in w if "uncertainty" in str(x.message).lower()] - assert len(uq_warnings) == 0 + # Should not warn -- UQ disabled + model(sim_state) def test_negative_uncertainty_threshold_raises(lj_model): From ce0fc7a7307ad72ac1f158766cca4a63d1590c17 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 18 Mar 2026 00:13:17 +0100 Subject: [PATCH 07/23] fix(tests): use tiny threshold to guarantee uncertainty warning fires --- python/metatomic_torchsim/tests/torchsim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index dc238fbf3..3e127f153 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -331,9 +331,9 @@ def test_variants_doubled(lj_model, ni_atoms): def test_uncertainty_warning_emitted(lj_model, ni_atoms): """Uncertainty warning fires when atoms exceed threshold.""" - # LJ test model's pseudo-uncertainty is 0.001 * n_atoms^2. - # For 32 atoms: 0.001 * 32^2 = 1.024 per atom. Set threshold below that. - model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=0.5) + # LJ test model pseudo-uncertainty scales with system size. + # Use a very small threshold to guarantee it fires. + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=1e-10) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) with pytest.warns(UserWarning, match="uncertainty"): model(sim_state) From fc013896be3291e34f229f5b963eff8732a61b96 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 18 Mar 2026 00:46:25 +0100 Subject: [PATCH 08/23] fix(tests): add filterwarnings marker so pytest.warns captures the warning --- python/metatomic_torchsim/tests/torchsim.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 3e127f153..736b284dc 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -329,6 +329,7 @@ def test_variants_doubled(lj_model, ni_atoms): # ---- Uncertainty ---- +@pytest.mark.filterwarnings("default::UserWarning") def test_uncertainty_warning_emitted(lj_model, ni_atoms): """Uncertainty warning fires when atoms exceed threshold.""" # LJ test model pseudo-uncertainty scales with system size. From 78c804812ba79730b098a316c207e2e3ca18941d Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 18 Mar 2026 01:54:11 +0100 Subject: [PATCH 09/23] fix(tests): use pytest.raises for warning test (filterwarnings=error) --- python/metatomic_torchsim/tests/torchsim.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 736b284dc..8d7ee18d5 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -329,14 +329,15 @@ def test_variants_doubled(lj_model, ni_atoms): # ---- Uncertainty ---- -@pytest.mark.filterwarnings("default::UserWarning") def test_uncertainty_warning_emitted(lj_model, ni_atoms): """Uncertainty warning fires when atoms exceed threshold.""" # LJ test model pseudo-uncertainty scales with system size. # Use a very small threshold to guarantee it fires. + # filterwarnings = ["error"] converts warnings to exceptions, + # so we catch it as an error. model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=1e-10) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - with pytest.warns(UserWarning, match="uncertainty"): + with pytest.raises(UserWarning, match="uncertainty"): model(sim_state) From 626f0f4347785a5d7e5b486756fee5f7a269accf Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 18 Mar 2026 02:46:09 +0100 Subject: [PATCH 10/23] fix(torchsim): handle bool pbc from torch-sim, fix warning test regex - Convert state.pbc (bool or array) to Tensor for System() constructor - Fix test regex to match "uncertainties are larger" (pytest.raises matches against str(exception) which wraps differently) Verified: 29 passed, 1 skipped on rg.cosmolab --- python/metatomic_torchsim/metatomic_torchsim/_model.py | 8 +++++++- python/metatomic_torchsim/tests/torchsim.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 0c8c02226..da7bc6e32 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -353,12 +353,18 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: sys_cell = sys_cell @ strain strains.append(strain) + pbc = state.pbc + if isinstance(pbc, bool): + pbc = torch.tensor([pbc, pbc, pbc]) + elif not isinstance(pbc, torch.Tensor): + pbc = torch.tensor(pbc) + systems.append( System( positions=sys_positions, types=sys_types, cell=sys_cell, - pbc=state.pbc, + pbc=pbc, ) ) diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 8d7ee18d5..101971b8b 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -337,7 +337,7 @@ def test_uncertainty_warning_emitted(lj_model, ni_atoms): # so we catch it as an error. model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=1e-10) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - with pytest.raises(UserWarning, match="uncertainty"): + with pytest.raises(UserWarning, match="uncertainties are larger"): model(sim_state) From da6b9616e15555eee93efab94f44ea91fc6046d1 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 18 Mar 2026 20:08:08 +0100 Subject: [PATCH 11/23] fix(torchsim): address PR #181 review comments - Remove doc sub-pages (architecture already in Architecture.md, getting-started/batched belong as tutorials, model-loading goes in docstring) - Remove toctree and "output variants" / "additional outputs" bullet from torch-sim.rst (expected for all engines) - Revert ts.io.atoms_to_state back to ts.initialize_state (torchsim docs convention) - Reformulate NC docstring to match ASE calculator wording - Fix variants docstring example per suggestion - Truncate uncertainty warning atom list to first 20 Verified: 29 passed, 1 skipped on rg.cosmolab --- docs/src/engines/torch-sim-architecture.rst | 87 ------------------- docs/src/engines/torch-sim-batched.rst | 79 ----------------- .../src/engines/torch-sim-getting-started.rst | 87 ------------------- docs/src/engines/torch-sim-model-loading.rst | 73 ---------------- docs/src/engines/torch-sim.rst | 15 +--- .../metatomic_torchsim/_model.py | 19 ++-- 6 files changed, 15 insertions(+), 345 deletions(-) delete mode 100644 docs/src/engines/torch-sim-architecture.rst delete mode 100644 docs/src/engines/torch-sim-batched.rst delete mode 100644 docs/src/engines/torch-sim-getting-started.rst delete mode 100644 docs/src/engines/torch-sim-model-loading.rst diff --git a/docs/src/engines/torch-sim-architecture.rst b/docs/src/engines/torch-sim-architecture.rst deleted file mode 100644 index caa4a1451..000000000 --- a/docs/src/engines/torch-sim-architecture.rst +++ /dev/null @@ -1,87 +0,0 @@ -.. _torchsim-architecture: - -Architecture -============ - -This page explains how ``MetatomicModel`` bridges TorchSim and -metatomic. - -SimState vs list of System --------------------------- - -TorchSim represents a simulation as a single batched ``SimState`` -containing all atoms from all systems, with a ``system_idx`` tensor -tracking ownership. Metatomic expects a ``list[System]`` where each -``System`` holds one periodic structure. - -``MetatomicModel.forward`` converts between these representations: - -1. Split the batched positions and atomic numbers by ``system_idx`` -2. Create one ``System`` per sub-structure with its own cell -3. Call the model on the list of systems -4. Concatenate results back into batched tensors - -Forces via autograd -------------------- - -Metatomic models typically output only total energies. Forces are -computed as the negative gradient of the energy with respect to atomic -positions:: - - F_i = -dE/dr_i - -Before calling the model, each system's positions are detached and set -to ``requires_grad_(True)``. After the forward pass, -``torch.autograd.grad`` computes the derivatives. - -When ``non_conservative=True``, forces and stresses are read directly -from the model's ``non_conservative_forces`` and -``non_conservative_stress`` outputs, bypassing autograd entirely. - -Stress via the strain trick ---------------------------- - -Stress is computed using the Knuth strain trick. An identity strain -tensor (3x3, ``requires_grad=True``) is applied to both positions and -cell vectors:: - - r' = r @ strain - h' = h @ strain - -The stress per system is then:: - - sigma = (1/V) * dE/d(strain) - -where V is the cell volume. This gives the full 3x3 stress tensor -without finite differences. - -Neighbor lists --------------- - -Models specify what neighbor lists they need via -``model.requested_neighbor_lists()``, which returns a list of -``NeighborListOptions`` (cutoff radius, full vs half list). - -The wrapper computes these using: - -- **vesin**: Default backend for both CPU and GPU. Handles half and - full neighbor lists. Systems on non-CPU/CUDA devices are temporarily - moved to CPU for the computation. -- **nvalchemiops**: Used automatically on CUDA for full neighbor lists - when installed. Keeps everything on GPU, avoiding host-device - transfers. - -The decision happens per-call in ``_compute_requested_neighbors``: if -all systems are on CUDA and nvalchemiops is available, full-list -requests go through nvalchemi while half-list requests still use vesin. - -Why a separate package ----------------------- - -metatomic-torchsim has its own versioning, release schedule, and -dependency set (``torch-sim-atomistic``). Keeping it separate from -metatomic-torch avoids forcing a torch-sim dependency on users who only -need the ASE calculator or other integrations. - -The package is pure Python with no compiled extensions, making it -lightweight to install. diff --git a/docs/src/engines/torch-sim-batched.rst b/docs/src/engines/torch-sim-batched.rst deleted file mode 100644 index c214b73ce..000000000 --- a/docs/src/engines/torch-sim-batched.rst +++ /dev/null @@ -1,79 +0,0 @@ -.. _torchsim-batched: - -Batched simulations -=================== - -TorchSim supports batching multiple systems into a single ``SimState`` -for efficient parallel evaluation on GPU. ``MetatomicModel`` handles -this transparently. - -Creating a batched state ------------------------- - -Pass a list of ASE ``Atoms`` objects to ``atoms_to_state``: - -.. code-block:: python - - import ase.build - import torch_sim as ts - from metatomic_torchsim import MetatomicModel - - model = MetatomicModel("model.pt", device="cpu") - - atoms_list = [ - ase.build.bulk("Cu", "fcc", a=3.6, cubic=True), - ase.build.bulk("Ni", "fcc", a=3.52, cubic=True), - ase.build.bulk("Al", "fcc", a=4.05, cubic=True), - ] - - sim_state = ts.io.atoms_to_state(atoms_list, model.device, model.dtype) - -Evaluating the batch --------------------- - -A single forward call evaluates all systems: - -.. code-block:: python - - results = model(sim_state) - -The output shapes reflect the batch: - -- ``results["energy"]`` has shape ``[n_systems]`` (one energy per system) -- ``results["forces"]`` has shape ``[n_total_atoms, 3]`` (all atoms - concatenated) -- ``results["stress"]`` has shape ``[n_systems, 3, 3]`` (one 3x3 tensor - per system) - -How system_idx works --------------------- - -``SimState`` tracks which atom belongs to which system via the -``system_idx`` tensor. For three 4-atom systems, ``system_idx`` looks -like:: - - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] - -``MetatomicModel.forward`` uses this to split the batched positions and -types into per-system ``System`` objects before calling the underlying -model. - -Batch consistency ------------------ - -Energies computed in a batch match those computed individually. This is -guaranteed because each system gets its own neighbor list and -independent evaluation. The existing test -``test_energy_consistency_single_vs_batch`` validates this property. - -Performance considerations --------------------------- - -Batching is most beneficial on GPU, where the neighbor list computation -and model forward pass can run in parallel across systems. On CPU, the -speedup comes from reduced Python overhead (one call instead of N). - -For very large systems or many small ones, adjust the batch size to fit -in GPU memory. TorchSim does not impose a maximum batch size, but each -system gets its own neighbor list, so memory scales with the sum of -per-system sizes. diff --git a/docs/src/engines/torch-sim-getting-started.rst b/docs/src/engines/torch-sim-getting-started.rst deleted file mode 100644 index 96249a49b..000000000 --- a/docs/src/engines/torch-sim-getting-started.rst +++ /dev/null @@ -1,87 +0,0 @@ -.. _torchsim-getting-started: - -Getting started -=============== - -This tutorial walks through running a short NVE molecular dynamics -simulation with a metatomic model and TorchSim. - -Prerequisites -------------- - -Install the package and its dependencies: - -.. code-block:: bash - - pip install metatomic-torchsim - -Load the model --------------- - -.. code-block:: python - - from metatomic_torchsim import MetatomicModel - - model = MetatomicModel("path/to/model.pt", device="cpu") - -The wrapper detects the model's dtype and supported devices -automatically. Pass ``device="cuda"`` to run on GPU. - -Build a simulation state ------------------------- - -TorchSim works with ``SimState`` objects. Convert ASE ``Atoms`` using -``torch_sim.io.atoms_to_state``: - -.. code-block:: python - - import ase.build - import torch_sim as ts - - atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) - sim_state = ts.io.atoms_to_state([atoms], model.device, model.dtype) - -Evaluate the model ------------------- - -Call the model on the simulation state to get energies, forces, and -stresses: - -.. code-block:: python - - results = model(sim_state) - - print("Energy:", results["energy"]) # shape [1] - print("Forces:", results["forces"]) # shape [n_atoms, 3] - print("Stress:", results["stress"]) # shape [1, 3, 3] - -Run NVE dynamics ----------------- - -Use TorchSim's Velocity Verlet integrator: - -.. code-block:: python - - from torch_sim.integrators import VelocityVerletIntegrator - - integrator = VelocityVerletIntegrator( - model=model, - state=sim_state, - dt=1.0, # femtoseconds - ) - - for step in range(100): - sim_state = integrator.step(sim_state) - if step % 10 == 0: - energy = model(sim_state)["energy"].item() - print(f"Step {step:3d} E = {energy:.4f} eV") - -The total energy should remain approximately constant in an NVE -simulation, which serves as a basic sanity check for your model. - -Next steps ----------- - -- :ref:`torchsim-model-loading` covers all supported input formats -- :ref:`torchsim-batched` explains running multiple systems at once -- :ref:`torchsim-architecture` describes the internals diff --git a/docs/src/engines/torch-sim-model-loading.rst b/docs/src/engines/torch-sim-model-loading.rst deleted file mode 100644 index 1cfa371d8..000000000 --- a/docs/src/engines/torch-sim-model-loading.rst +++ /dev/null @@ -1,73 +0,0 @@ -.. _torchsim-model-loading: - -Loading models -============== - -``MetatomicModel`` accepts several input formats. Each section below -shows one loading pattern. - -From a saved ``.pt`` file -------------------------- - -The most common case. Pass the path to a TorchScript-exported metatomic -model: - -.. code-block:: python - - from metatomic_torchsim import MetatomicModel - - model = MetatomicModel("path/to/model.pt", device="cpu") - -The file must exist and contain a valid ``AtomisticModel``. A -``ValueError`` is raised if the path does not exist. - -From a Python AtomisticModel ------------------------------ - -If you already have an ``AtomisticModel`` instance (for example, built -programmatically): - -.. code-block:: python - - from metatomic.torch import AtomisticModel - - atomistic_model = build_my_model() # returns AtomisticModel - model = MetatomicModel(atomistic_model, device="cuda") - -From a TorchScript RecursiveScriptModule ------------------------------------------ - -If you have a scripted model loaded via ``torch.jit.load``: - -.. code-block:: python - - import torch - - scripted = torch.jit.load("model.pt") - model = MetatomicModel(scripted, device="cpu") - -The script module must have ``original_name == "AtomisticModel"``. -Otherwise a ``TypeError`` is raised. - -Selecting a device ------------------- - -By default, ``MetatomicModel`` picks the best device from the model's -``supported_devices``. Override with the ``device`` parameter: - -.. code-block:: python - - model = MetatomicModel("model.pt", device="cuda:0") - -Extensions directory --------------------- - -Some models require compiled TorchScript extensions. Point to their -location with ``extensions_directory``: - -.. code-block:: python - - model = MetatomicModel( - "model.pt", - extensions_directory="path/to/extensions/", - ) diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 4d396565f..0622faada 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -28,12 +28,10 @@ Supported model outputs The :ref:`energy ` output is the primary output. Forces and stresses are derived via autograd by default. The wrapper also supports: -- **Output variants**: select among model variants (e.g. ``"energy/pbe"``) -- **Non-conservative forces/stress**: read directly from model outputs instead +- **Non-conservative forces/stress**: use direct prediction of gradients instead of autograd (``non_conservative=True``) - **Energy uncertainty**: per-atom uncertainty warnings when the model provides an ``energy_uncertainty`` output -- **Additional outputs**: request arbitrary extra model outputs How to use the code ^^^^^^^^^^^^^^^^^^^ @@ -47,7 +45,7 @@ How to use the code model = MetatomicModel("model.pt", device="cpu") atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) - sim_state = ts.io.atoms_to_state([atoms], model.device, model.dtype) + sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) results = model(sim_state) print(results["energy"]) # shape [1] @@ -60,12 +58,3 @@ API documentation .. autoclass:: metatomic_torchsim.MetatomicModel :show-inheritance: :members: - -.. toctree:: - :maxdepth: 2 - :caption: torch-sim integration - - torch-sim-getting-started - torch-sim-model-loading - torch-sim-batched - torch-sim-architecture diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index da7bc6e32..17100a055 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -101,11 +101,12 @@ def __init__( be used. Setting ``{"energy": "pbe"}`` selects the ``"energy/pbe"`` output. The energy variant propagates to uncertainty and non-conservative outputs unless overridden (e.g. - ``{"energy_uncertainty": "r2scan"}``). - :param non_conservative: If ``True``, forces and stresses are read - directly from the model's ``non_conservative_forces`` and - ``non_conservative_stress`` outputs instead of being computed via - autograd. + ``{"energy": "pbe", "energy_uncertainty": "r2scan"}`` would select + ``energy/pbe`` and ``energy_uncertainty/r2scan``). + :param non_conservative: If ``True``, the model will be asked to compute + non-conservative forces and stresses. This can afford a speed-up, + potentially at the expense of physical correctness (especially in + molecular dynamics simulations). :param uncertainty_threshold: Threshold for per-atom energy uncertainty in eV. When the model supports ``energy_uncertainty`` with ``per_atom=True``, atoms exceeding this threshold trigger a warning. @@ -398,10 +399,16 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: threshold = self._uncertainty_threshold if torch.any(uncertainty > threshold): exceeded = torch.where(uncertainty.squeeze(-1) > threshold)[0] + atom_list = exceeded.tolist() + if len(atom_list) > 20: + atom_list = atom_list[:20] + suffix = f" (and {len(exceeded) - 20} more)" + else: + suffix = "" warnings.warn( "Some of the atomic energy uncertainties are larger than the " f"threshold of {threshold} eV. The prediction is above the " - f"threshold for atoms {exceeded.tolist()}.", + f"threshold for atoms {atom_list}{suffix}.", stacklevel=2, ) From 160cc23ae06c3820e8cb38521da76643e9bc7940 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 18 Mar 2026 20:31:17 +0100 Subject: [PATCH 12/23] fix(torchsim): final review cleanup - Remove stale CHANGELOG entry about deleted doc sub-pages - Hoist pbc bool-to-Tensor conversion above per-system loop - Convert uncertainty shape assert to ValueError (survives -O) --- python/metatomic_torchsim/CHANGELOG.md | 2 -- .../metatomic_torchsim/_model.py | 21 ++++++++++--------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/python/metatomic_torchsim/CHANGELOG.md b/python/metatomic_torchsim/CHANGELOG.md index 0e5952b1e..027e53b5e 100644 --- a/python/metatomic_torchsim/CHANGELOG.md +++ b/python/metatomic_torchsim/CHANGELOG.md @@ -28,5 +28,3 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Per-atom energy uncertainty warnings via `uncertainty_threshold`, triggered when the model provides `energy_uncertainty` with `per_atom=True` - `additional_outputs` parameter for requesting arbitrary extra model outputs -- Documentation sub-pages: getting started, model loading, batched simulations, - architecture diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 17100a055..28b8cfedb 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -334,6 +334,12 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: strains: List[torch.Tensor] = [] n_systems = len(cell) + pbc = state.pbc + if isinstance(pbc, bool): + pbc = torch.tensor([pbc, pbc, pbc]) + elif not isinstance(pbc, torch.Tensor): + pbc = torch.tensor(pbc) + for sys_idx in range(n_systems): mask = state.system_idx == sys_idx sys_positions = positions[mask] @@ -354,12 +360,6 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: sys_cell = sys_cell @ strain strains.append(strain) - pbc = state.pbc - if isinstance(pbc, bool): - pbc = torch.tensor([pbc, pbc, pbc]) - elif not isinstance(pbc, torch.Tensor): - pbc = torch.tensor(pbc) - systems.append( System( positions=sys_positions, @@ -392,10 +392,11 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: if self._calculate_uncertainty: uncertainty = model_outputs[self._energy_uq_key].block().values n_total_atoms = positions.shape[0] - assert uncertainty.shape == (n_total_atoms, 1), ( - f"expected uncertainty shape ({n_total_atoms}, 1), " - f"got {uncertainty.shape}" - ) + if uncertainty.shape != (n_total_atoms, 1): + raise ValueError( + f"expected uncertainty shape ({n_total_atoms}, 1), " + f"got {uncertainty.shape}" + ) threshold = self._uncertainty_threshold if torch.any(uncertainty > threshold): exceeded = torch.where(uncertainty.squeeze(-1) > threshold)[0] From 7e53f7e61a8811368bf24636000084cf0bf94fcd Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 18 Mar 2026 20:33:50 +0100 Subject: [PATCH 13/23] fix(torchsim): address all remaining review items - Always resolve both NC keys via pick_output when non_conservative=True (match ASE behavior: validate model supports outputs at construction) - Simplify additional_outputs validation to match ASE pattern (assert isinstance ScriptObject, no internal method duck-typing) - Add test: stress-only NC mode (compute_forces=False) - Add test: NC variant doubled verifies 2x energy - Add test: invalid additional_outputs raises AssertionError Verified: 31 passed, 1 skipped on rg.cosmolab --- .../metatomic_torchsim/_model.py | 45 ++++++------------ python/metatomic_torchsim/tests/torchsim.py | 47 ++++++++++++++++--- 2 files changed, 56 insertions(+), 36 deletions(-) diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 28b8cfedb..be53da4ec 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -213,23 +213,16 @@ def __init__( "must either be both `None` or both not `None`." ) - if compute_forces: - self._nc_forces_key = pick_output( - "non_conservative_forces", - outputs, - resolved_variants["non_conservative_forces"], - ) - else: - self._nc_forces_key = "non_conservative_forces" - - if compute_stress: - self._nc_stress_key = pick_output( - "non_conservative_stress", - outputs, - resolved_variants["non_conservative_stress"], - ) - else: - self._nc_stress_key = "non_conservative_stress" + self._nc_forces_key = pick_output( + "non_conservative_forces", + outputs, + resolved_variants["non_conservative_forces"], + ) + self._nc_stress_key = pick_output( + "non_conservative_stress", + outputs, + resolved_variants["non_conservative_stress"], + ) else: self._nc_forces_key = "non_conservative_forces" self._nc_stress_key = "non_conservative_stress" @@ -238,20 +231,12 @@ def __init__( if additional_outputs is None: self._additional_output_requests: Dict[str, ModelOutput] = {} else: + assert isinstance(additional_outputs, dict) for name, output in additional_outputs.items(): - if not isinstance(name, str): - raise TypeError( - f"additional_outputs keys must be strings, got {type(name)}" - ) - if ( - not isinstance(output, torch.ScriptObject) - or not hasattr(output, "_method_names") - or "explicit_gradients_setter" not in output._method_names() - ): - raise TypeError( - f"additional_outputs['{name}'] must be a ModelOutput " - f"instance, got {type(output)}" - ) + assert isinstance(name, str) + assert isinstance(output, torch.ScriptObject), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs self._model = model.to(device=self._device) diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 101971b8b..33acbead5 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -444,9 +444,27 @@ def test_non_conservative_missing_output_raises(lj_model_ext): MetatomicModel(model=lj_model_ext, device=DEVICE, non_conservative=True) -def test_non_conservative_with_variants(lj_model, ni_atoms): - """NC outputs respect variant selection.""" +def test_non_conservative_stress_only(lj_model, ni_atoms): + """NC mode with compute_forces=False returns only stress.""" model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + compute_forces=False, + ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) + output = model(sim_state) + + assert "energy" in output + assert "forces" not in output + assert "stress" in output + assert output["stress"].shape == (1, 3, 3) + + +def test_non_conservative_with_variants(lj_model, ni_atoms): + """NC doubled variant gives different forces than base variant.""" + model_base = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) + model_doubled = MetatomicModel( model=lj_model, device=DEVICE, non_conservative=True, @@ -456,9 +474,26 @@ def test_non_conservative_with_variants(lj_model, ni_atoms): "non_conservative_stress": "doubled", }, ) + sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) - output = model(sim_state) + out_base = model_base(sim_state) + out_doubled = model_doubled(sim_state) - assert "energy" in output - assert "forces" in output - assert "stress" in output + assert "energy" in out_doubled + assert "forces" in out_doubled + assert "stress" in out_doubled + + # Doubled energy should be 2x base + torch.testing.assert_close( + out_doubled["energy"], 2.0 * out_base["energy"], atol=1e-10, rtol=0 + ) + + +def test_additional_outputs_invalid_raises(lj_model): + """Passing non-ModelOutput values raises AssertionError.""" + with pytest.raises(AssertionError): + MetatomicModel( + model=lj_model, + device=DEVICE, + additional_outputs={"bad": "not a ModelOutput"}, + ) From 66ca9f9402b3efd3cea77e8cc404499ea4d8b5f5 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 25 Mar 2026 00:13:46 +0100 Subject: [PATCH 14/23] fix(torchsim): address new review comments from 2026-03-24 - Default uncertainty_threshold=0.1 to match ASE calculator - Add additional_outputs bullet to supported outputs docs with link to API documentation - Mention r-RESPA / multiple time stepping in NC docstring - All non-UQ tests pass uncertainty_threshold=None to avoid warnings-as-errors with the pure-PyTorch LJ model --- docs/src/engines/torch-sim.rst | 9 +- .../metatomic_torchsim/_model.py | 6 +- python/metatomic_torchsim/tests/torchsim.py | 88 +++++++++++++++---- 3 files changed, 84 insertions(+), 19 deletions(-) diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 0622faada..8ff977dcf 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -26,12 +26,19 @@ Supported model outputs ^^^^^^^^^^^^^^^^^^^^^^^ The :ref:`energy ` output is the primary output. Forces and -stresses are derived via autograd by default. The wrapper also supports: +stresses are derived via autograd by default. The wrapper also supports: - **Non-conservative forces/stress**: use direct prediction of gradients instead of autograd (``non_conservative=True``) - **Energy uncertainty**: per-atom uncertainty warnings when the model provides an ``energy_uncertainty`` output +- **Additional outputs**: request arbitrary extra model outputs via + ``additional_outputs``; results are stored as + :py:class:`metatensor.torch.TensorMap` in the + :py:attr:`~metatomic_torchsim.MetatomicModel.additional_outputs` attribute + +See the :py:class:`~metatomic_torchsim.MetatomicModel` API documentation below +for details on all parameters. How to use the code ^^^^^^^^^^^^^^^^^^^ diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index be53da4ec..fa434444a 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -82,7 +82,7 @@ def __init__( compute_stress: bool = True, variants: Optional[Dict[str, Optional[str]]] = None, non_conservative: bool = False, - uncertainty_threshold: Optional[float] = None, + uncertainty_threshold: Optional[float] = 0.1, additional_outputs: Optional[Dict[str, ModelOutput]] = None, ) -> None: """ @@ -106,7 +106,9 @@ def __init__( :param non_conservative: If ``True``, the model will be asked to compute non-conservative forces and stresses. This can afford a speed-up, potentially at the expense of physical correctness (especially in - molecular dynamics simulations). + molecular dynamics simulations). Non-conservative outputs are also + useful for multiple time stepping schemes (e.g. r-RESPA), where + different force components are evaluated at different frequencies. :param uncertainty_threshold: Threshold for per-atom energy uncertainty in eV. When the model supports ``energy_uncertainty`` with ``per_atom=True``, atoms exceeding this threshold trigger a warning. diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 33acbead5..775c912e8 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -67,7 +67,7 @@ def ni_atoms(): @pytest.fixture def metatomic_model(lj_model): - return MetatomicModel(model=lj_model, device=DEVICE) + return MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) def test_initialization(lj_model): @@ -81,7 +81,12 @@ def test_initialization(lj_model): def test_initialization_no_forces(lj_model): """Can disable force computation.""" - model = MetatomicModel(model=lj_model, device=DEVICE, compute_forces=False) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + compute_forces=False, + uncertainty_threshold=None, + ) assert model.compute_forces is False assert model.compute_stress is True @@ -119,7 +124,9 @@ def test_forward_returns_stress(metatomic_model, ni_atoms): def test_forward_no_stress(lj_model, ni_atoms): """Stress is not returned when compute_stress=False.""" - model = MetatomicModel(model=lj_model, device=DEVICE, compute_stress=False) + model = MetatomicModel( + model=lj_model, device=DEVICE, compute_stress=False, uncertainty_threshold=None + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -130,7 +137,12 @@ def test_forward_no_stress(lj_model, ni_atoms): def test_forward_no_forces(lj_model, ni_atoms): """Forces are not returned when compute_forces=False.""" - model = MetatomicModel(model=lj_model, device=DEVICE, compute_forces=False) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + compute_forces=False, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -227,7 +239,7 @@ def test_single_atom_system(lj_model): cell=[10.0, 10.0, 10.0], pbc=True, ) - model = MetatomicModel(model=lj_model, device=DEVICE) + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) sim_state = ts.io.atoms_to_state([atoms], DEVICE, DTYPE) output = model(sim_state) @@ -251,7 +263,12 @@ def test_energy_only_mode(lj_model, ni_atoms): def test_check_consistency_mode(lj_model, ni_atoms): """Model runs with consistency checking enabled.""" - model = MetatomicModel(model=lj_model, device=DEVICE, check_consistency=True) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + check_consistency=True, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -263,7 +280,9 @@ def test_check_consistency_mode(lj_model, ni_atoms): def test_forces_match_finite_difference(lj_model, ni_atoms): """Autograd forces match finite-difference gradient of energy.""" delta = 1e-4 - model = MetatomicModel(model=lj_model, device=DEVICE, compute_stress=False) + model = MetatomicModel( + model=lj_model, device=DEVICE, compute_stress=False, uncertainty_threshold=None + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) autograd_forces = output["forces"] @@ -304,7 +323,12 @@ def test_stress_is_symmetric(metatomic_model, ni_atoms): def test_variants_default(lj_model, ni_atoms): """Default variant (None) selects the base energy output.""" - model = MetatomicModel(model=lj_model, device=DEVICE, variants={"energy": None}) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + variants={"energy": None}, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -314,9 +338,14 @@ def test_variants_default(lj_model, ni_atoms): def test_variants_doubled(lj_model, ni_atoms): """Selecting the 'doubled' variant gives 2x the base energy.""" - model_base = MetatomicModel(model=lj_model, device=DEVICE) + model_base = MetatomicModel( + model=lj_model, device=DEVICE, uncertainty_threshold=None + ) model_doubled = MetatomicModel( - model=lj_model, device=DEVICE, variants={"energy": "doubled"} + model=lj_model, + device=DEVICE, + variants={"energy": "doubled"}, + uncertainty_threshold=None, ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) @@ -368,7 +397,7 @@ def test_negative_uncertainty_threshold_raises(lj_model): def test_additional_outputs_empty(lj_model, ni_atoms): """additional_outputs defaults to empty dict.""" - model = MetatomicModel(model=lj_model, device=DEVICE) + model = MetatomicModel(model=lj_model, device=DEVICE, uncertainty_threshold=None) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) model(sim_state) assert model.additional_outputs == {} @@ -379,7 +408,12 @@ def test_additional_outputs_requested(lj_model, ni_atoms): extra = { "energy_ensemble": ModelOutput(quantity="energy", unit="eV", per_atom=True), } - model = MetatomicModel(model=lj_model, device=DEVICE, additional_outputs=extra) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + additional_outputs=extra, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) model(sim_state) @@ -394,7 +428,12 @@ def test_additional_outputs_requested(lj_model, ni_atoms): def test_non_conservative_forces(lj_model, ni_atoms): """NC forces are returned without autograd.""" - model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -409,7 +448,12 @@ def test_non_conservative_forces(lj_model, ni_atoms): def test_non_conservative_stress(lj_model, ni_atoms): """NC stress is returned with correct shape.""" - model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -419,7 +463,12 @@ def test_non_conservative_stress(lj_model, ni_atoms): def test_non_conservative_batched_forces(lj_model, ni_atoms): """NC net-force subtraction is per-system in batched mode.""" - model = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) + model = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) ni_atoms_2 = ni_atoms.copy() ni_atoms_2.positions += 0.3 * np.random.rand(*ni_atoms_2.positions.shape) @@ -451,6 +500,7 @@ def test_non_conservative_stress_only(lj_model, ni_atoms): device=DEVICE, non_conservative=True, compute_forces=False, + uncertainty_threshold=None, ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) @@ -463,11 +513,17 @@ def test_non_conservative_stress_only(lj_model, ni_atoms): def test_non_conservative_with_variants(lj_model, ni_atoms): """NC doubled variant gives different forces than base variant.""" - model_base = MetatomicModel(model=lj_model, device=DEVICE, non_conservative=True) + model_base = MetatomicModel( + model=lj_model, + device=DEVICE, + non_conservative=True, + uncertainty_threshold=None, + ) model_doubled = MetatomicModel( model=lj_model, device=DEVICE, non_conservative=True, + uncertainty_threshold=None, variants={ "energy": "doubled", "non_conservative_forces": "doubled", From 5d5f4cbf5c3cb72fa71d7876907042414fdf4e64 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 25 Mar 2026 00:18:46 +0100 Subject: [PATCH 15/23] fix(tests): add uncertainty_threshold=None to test_energy_only_mode --- python/metatomic_torchsim/tests/torchsim.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torchsim/tests/torchsim.py b/python/metatomic_torchsim/tests/torchsim.py index 775c912e8..03f229fc9 100644 --- a/python/metatomic_torchsim/tests/torchsim.py +++ b/python/metatomic_torchsim/tests/torchsim.py @@ -251,7 +251,11 @@ def test_single_atom_system(lj_model): def test_energy_only_mode(lj_model, ni_atoms): """Model returns only energy when forces and stress are disabled.""" model = MetatomicModel( - model=lj_model, device=DEVICE, compute_forces=False, compute_stress=False + model=lj_model, + device=DEVICE, + compute_forces=False, + compute_stress=False, + uncertainty_threshold=None, ) sim_state = ts.io.atoms_to_state([ni_atoms], DEVICE, DTYPE) output = model(sim_state) From f85bcae85bdd0b1b8db260c51604542efdb72434 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 25 Mar 2026 00:21:32 +0100 Subject: [PATCH 16/23] docs(torchsim): re-add getting-started and batched as tutorials Per review: these belong as tutorials under the torch-sim engine page. Architecture stays in Architecture.md (developer docs). Model-loading content distilled into the MetatomicModel docstring. --- docs/src/engines/torch-sim-batched.rst | 78 +++++++++++++++++ .../src/engines/torch-sim-getting-started.rst | 85 +++++++++++++++++++ docs/src/engines/torch-sim.rst | 7 ++ 3 files changed, 170 insertions(+) create mode 100644 docs/src/engines/torch-sim-batched.rst create mode 100644 docs/src/engines/torch-sim-getting-started.rst diff --git a/docs/src/engines/torch-sim-batched.rst b/docs/src/engines/torch-sim-batched.rst new file mode 100644 index 000000000..01a25d548 --- /dev/null +++ b/docs/src/engines/torch-sim-batched.rst @@ -0,0 +1,78 @@ +.. _torchsim-batched: + +Batched simulations +=================== + +TorchSim supports batching multiple systems into a single ``SimState`` +for efficient parallel evaluation on GPU. ``MetatomicModel`` handles +this transparently. + +Creating a batched state +------------------------ + +Pass a list of ASE ``Atoms`` objects to ``initialize_state``: + +.. code-block:: python + + import ase.build + import torch_sim as ts + from metatomic_torchsim import MetatomicModel + + model = MetatomicModel("model.pt", device="cpu") + + atoms_list = [ + ase.build.bulk("Cu", "fcc", a=3.6, cubic=True), + ase.build.bulk("Ni", "fcc", a=3.52, cubic=True), + ase.build.bulk("Al", "fcc", a=4.05, cubic=True), + ] + + sim_state = ts.initialize_state(atoms_list, device=model.device, dtype=model.dtype) + +Evaluating the batch +-------------------- + +A single forward call evaluates all systems: + +.. code-block:: python + + results = model(sim_state) + +The output shapes reflect the batch: + +- ``results["energy"]`` has shape ``[n_systems]`` (one energy per system) +- ``results["forces"]`` has shape ``[n_total_atoms, 3]`` (all atoms + concatenated) +- ``results["stress"]`` has shape ``[n_systems, 3, 3]`` (one 3x3 tensor + per system) + +How system_idx works +-------------------- + +``SimState`` tracks which atom belongs to which system via the +``system_idx`` tensor. For three 4-atom systems, ``system_idx`` looks +like:: + + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + +``MetatomicModel.forward`` uses this to split the batched positions and +types into per-system ``System`` objects before calling the underlying +model. + +Batch consistency +----------------- + +Energies computed in a batch match those computed individually. This is +guaranteed because each system gets its own neighbor list and +independent evaluation. + +Performance considerations +-------------------------- + +Batching is most beneficial on GPU, where the neighbor list computation +and model forward pass can run in parallel across systems. On CPU, the +speedup comes from reduced Python overhead (one call instead of N). + +For very large systems or many small ones, adjust the batch size to fit +in GPU memory. TorchSim does not impose a maximum batch size, but each +system gets its own neighbor list, so memory scales with the sum of +per-system sizes. diff --git a/docs/src/engines/torch-sim-getting-started.rst b/docs/src/engines/torch-sim-getting-started.rst new file mode 100644 index 000000000..afb085b58 --- /dev/null +++ b/docs/src/engines/torch-sim-getting-started.rst @@ -0,0 +1,85 @@ +.. _torchsim-getting-started: + +Getting started +=============== + +This tutorial walks through running a short NVE molecular dynamics +simulation with a metatomic model and TorchSim. + +Prerequisites +------------- + +Install the package and its dependencies: + +.. code-block:: bash + + pip install metatomic-torchsim + +Load the model +-------------- + +.. code-block:: python + + from metatomic_torchsim import MetatomicModel + + model = MetatomicModel("path/to/model.pt", device="cpu") + +The wrapper detects the model's dtype and supported devices +automatically. Pass ``device="cuda"`` to run on GPU. + +Build a simulation state +------------------------ + +TorchSim works with ``SimState`` objects. Convert ASE ``Atoms`` using +``torch_sim.initialize_state``: + +.. code-block:: python + + import ase.build + import torch_sim as ts + + atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) + sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + +Evaluate the model +------------------ + +Call the model on the simulation state to get energies, forces, and +stresses: + +.. code-block:: python + + results = model(sim_state) + + print("Energy:", results["energy"]) # shape [1] + print("Forces:", results["forces"]) # shape [n_atoms, 3] + print("Stress:", results["stress"]) # shape [1, 3, 3] + +Run NVE dynamics +---------------- + +Use TorchSim's Velocity Verlet integrator: + +.. code-block:: python + + from torch_sim.integrators import VelocityVerletIntegrator + + integrator = VelocityVerletIntegrator( + model=model, + state=sim_state, + dt=1.0, # femtoseconds + ) + + for step in range(100): + sim_state = integrator.step(sim_state) + if step % 10 == 0: + energy = model(sim_state)["energy"].item() + print(f"Step {step:3d} E = {energy:.4f} eV") + +The total energy should remain approximately constant in an NVE +simulation, which serves as a basic sanity check for your model. + +Next steps +---------- + +- :ref:`torchsim-batched` explains running multiple systems at once diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 8ff977dcf..0c5f0c5ad 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -65,3 +65,10 @@ API documentation .. autoclass:: metatomic_torchsim.MetatomicModel :show-inheritance: :members: + +.. toctree:: + :maxdepth: 2 + :caption: Tutorials + + torch-sim-getting-started + torch-sim-batched From 3c4e6f36475c960cf3bfb82cfc6ef4e5b7269421 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 25 Mar 2026 15:42:03 +0100 Subject: [PATCH 17/23] docs(torchsim): convert tutorials to sphinx-gallery .py files Replace plain RST tutorials with sphinx-gallery Python scripts that are self-contained and runnable in the docs build environment. Each tutorial exports a minimal constant-energy model and demonstrates the MetatomicModel / TorchSim workflow end-to-end. --- docs/src/engines/torch-sim-batched.rst | 78 ------- .../src/engines/torch-sim-getting-started.rst | 85 -------- docs/src/engines/torch-sim.rst | 4 +- python/examples/5-torchsim-getting-started.py | 204 ++++++++++++++++++ python/examples/6-torchsim-batched.py | 175 +++++++++++++++ 5 files changed, 381 insertions(+), 165 deletions(-) delete mode 100644 docs/src/engines/torch-sim-batched.rst delete mode 100644 docs/src/engines/torch-sim-getting-started.rst create mode 100644 python/examples/5-torchsim-getting-started.py create mode 100644 python/examples/6-torchsim-batched.py diff --git a/docs/src/engines/torch-sim-batched.rst b/docs/src/engines/torch-sim-batched.rst deleted file mode 100644 index 01a25d548..000000000 --- a/docs/src/engines/torch-sim-batched.rst +++ /dev/null @@ -1,78 +0,0 @@ -.. _torchsim-batched: - -Batched simulations -=================== - -TorchSim supports batching multiple systems into a single ``SimState`` -for efficient parallel evaluation on GPU. ``MetatomicModel`` handles -this transparently. - -Creating a batched state ------------------------- - -Pass a list of ASE ``Atoms`` objects to ``initialize_state``: - -.. code-block:: python - - import ase.build - import torch_sim as ts - from metatomic_torchsim import MetatomicModel - - model = MetatomicModel("model.pt", device="cpu") - - atoms_list = [ - ase.build.bulk("Cu", "fcc", a=3.6, cubic=True), - ase.build.bulk("Ni", "fcc", a=3.52, cubic=True), - ase.build.bulk("Al", "fcc", a=4.05, cubic=True), - ] - - sim_state = ts.initialize_state(atoms_list, device=model.device, dtype=model.dtype) - -Evaluating the batch --------------------- - -A single forward call evaluates all systems: - -.. code-block:: python - - results = model(sim_state) - -The output shapes reflect the batch: - -- ``results["energy"]`` has shape ``[n_systems]`` (one energy per system) -- ``results["forces"]`` has shape ``[n_total_atoms, 3]`` (all atoms - concatenated) -- ``results["stress"]`` has shape ``[n_systems, 3, 3]`` (one 3x3 tensor - per system) - -How system_idx works --------------------- - -``SimState`` tracks which atom belongs to which system via the -``system_idx`` tensor. For three 4-atom systems, ``system_idx`` looks -like:: - - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] - -``MetatomicModel.forward`` uses this to split the batched positions and -types into per-system ``System`` objects before calling the underlying -model. - -Batch consistency ------------------ - -Energies computed in a batch match those computed individually. This is -guaranteed because each system gets its own neighbor list and -independent evaluation. - -Performance considerations --------------------------- - -Batching is most beneficial on GPU, where the neighbor list computation -and model forward pass can run in parallel across systems. On CPU, the -speedup comes from reduced Python overhead (one call instead of N). - -For very large systems or many small ones, adjust the batch size to fit -in GPU memory. TorchSim does not impose a maximum batch size, but each -system gets its own neighbor list, so memory scales with the sum of -per-system sizes. diff --git a/docs/src/engines/torch-sim-getting-started.rst b/docs/src/engines/torch-sim-getting-started.rst deleted file mode 100644 index afb085b58..000000000 --- a/docs/src/engines/torch-sim-getting-started.rst +++ /dev/null @@ -1,85 +0,0 @@ -.. _torchsim-getting-started: - -Getting started -=============== - -This tutorial walks through running a short NVE molecular dynamics -simulation with a metatomic model and TorchSim. - -Prerequisites -------------- - -Install the package and its dependencies: - -.. code-block:: bash - - pip install metatomic-torchsim - -Load the model --------------- - -.. code-block:: python - - from metatomic_torchsim import MetatomicModel - - model = MetatomicModel("path/to/model.pt", device="cpu") - -The wrapper detects the model's dtype and supported devices -automatically. Pass ``device="cuda"`` to run on GPU. - -Build a simulation state ------------------------- - -TorchSim works with ``SimState`` objects. Convert ASE ``Atoms`` using -``torch_sim.initialize_state``: - -.. code-block:: python - - import ase.build - import torch_sim as ts - - atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) - sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) - -Evaluate the model ------------------- - -Call the model on the simulation state to get energies, forces, and -stresses: - -.. code-block:: python - - results = model(sim_state) - - print("Energy:", results["energy"]) # shape [1] - print("Forces:", results["forces"]) # shape [n_atoms, 3] - print("Stress:", results["stress"]) # shape [1, 3, 3] - -Run NVE dynamics ----------------- - -Use TorchSim's Velocity Verlet integrator: - -.. code-block:: python - - from torch_sim.integrators import VelocityVerletIntegrator - - integrator = VelocityVerletIntegrator( - model=model, - state=sim_state, - dt=1.0, # femtoseconds - ) - - for step in range(100): - sim_state = integrator.step(sim_state) - if step % 10 == 0: - energy = model(sim_state)["energy"].item() - print(f"Step {step:3d} E = {energy:.4f} eV") - -The total energy should remain approximately constant in an NVE -simulation, which serves as a basic sanity check for your model. - -Next steps ----------- - -- :ref:`torchsim-batched` explains running multiple systems at once diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 0c5f0c5ad..1586d0173 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -70,5 +70,5 @@ API documentation :maxdepth: 2 :caption: Tutorials - torch-sim-getting-started - torch-sim-batched + ../examples/5-torchsim-getting-started + ../examples/6-torchsim-batched diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py new file mode 100644 index 000000000..8086e5b1a --- /dev/null +++ b/python/examples/5-torchsim-getting-started.py @@ -0,0 +1,204 @@ +""" +.. _torchsim-getting-started: + +Getting started with TorchSim +============================= + +This tutorial walks through running a short NVE molecular dynamics +simulation with a metatomic model and `TorchSim +`_. +""" + +# %% +# +# Prerequisites +# ------------- +# +# Install the integration package and its dependencies: +# +# .. code-block:: bash +# +# pip install metatomic-torchsim +# +# We start by importing the modules we need: + +import os +import tempfile +from typing import Dict, List, Optional + +import ase.build +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +import metatomic.torch as mta +from metatomic_torchsim import MetatomicModel + + +# sphinx_gallery_thumbnail_number = 2 + +# %% +# +# Export a simple model +# --------------------- +# +# For this tutorial we create and export a minimal model that predicts a +# constant per-atom energy. In practice you would use a pre-trained model +# loaded from a file. + + +class ConstantEnergy(torch.nn.Module): + """A minimal model that assigns a constant energy to each atom.""" + + def __init__(self, energy_per_atom: float = -1.0): + super().__init__() + self.energy_per_atom = energy_per_atom + + def forward( + self, + systems: List[mta.System], + outputs: Dict[str, mta.ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + energies = [] + for system in systems: + n_atoms = len(system) + energies.append(self.energy_per_atom * n_atoms) + + energy = torch.tensor(energies, dtype=systems[0].positions.dtype).reshape(-1, 1) + + block = TensorBlock( + values=energy, + samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + return { + "energy": TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block]) + } + + +# %% +# +# Export the model to a file so that ``MetatomicModel`` can load it: + +tmpdir = tempfile.mkdtemp() +model_path = os.path.join(tmpdir, "constant-energy.pt") + +raw_model = ConstantEnergy(energy_per_atom=-1.5) +capabilities = mta.ModelCapabilities( + length_unit="Angstrom", + atomic_types=[14], # Silicon + interaction_range=0.0, + outputs={"energy": mta.ModelOutput(quantity="energy", unit="eV")}, + supported_devices=["cpu"], + dtype="float64", +) + +atomistic_model = mta.AtomisticModel( + raw_model.eval(), mta.ModelMetadata(), capabilities +) +atomistic_model.save(model_path) + +# %% +# +# Load the model +# -------------- +# +# Wrap the exported model with :py:class:`~metatomic_torchsim.MetatomicModel`: + +model = MetatomicModel(model_path, device="cpu") + +# %% +# +# The wrapper detects the model's dtype and supported devices +# automatically. Pass ``device="cuda"`` to run on GPU when available. + +print("dtype:", model.dtype) +print("device:", model.device) + +# %% +# +# Build a simulation state +# ------------------------ +# +# TorchSim works with ``SimState`` objects. Convert ASE ``Atoms`` using +# ``torch_sim.initialize_state``: + +import torch_sim as ts # noqa: E402 + + +atoms = ase.build.bulk("Si", "diamond", a=5.43, cubic=True) +sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + +print("Number of atoms:", sim_state.n_atoms) + +# %% +# +# Evaluate the model +# ------------------ +# +# Call the model on the simulation state to get energies, forces, and +# stresses: + +results = model(sim_state) + +print("Energy:", results["energy"]) # shape [1] +print("Forces shape:", results["forces"].shape) # shape [n_atoms, 3] +print("Stress shape:", results["stress"].shape) # shape [1, 3, 3] + +# %% +# +# Run NVE dynamics +# ---------------- +# +# Use TorchSim's Velocity Verlet integrator to run a short NVE trajectory: + +import matplotlib.pyplot as plt # noqa: E402 + + +sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + +# Initialize with small random velocities +sim_state.velocities = 0.001 * torch.randn_like(sim_state.velocities) + +energies = [] +steps = [] + +integrator = ts.integrators.VelocityVerletIntegrator(dt=1.0) + +for step in range(50): + sim_state = integrator.step(sim_state, model) + energy = results["energy"].item() + energies.append(energy) + steps.append(step) + +plt.plot(steps, energies) +plt.xlabel("Step") +plt.ylabel("Energy (eV)") +plt.title("NVE dynamics -- energy vs step") +plt.tight_layout() +plt.show() + + +# %% +# +# .. note:: +# +# With a real interatomic potential the total energy would stay approximately +# constant in an NVE simulation, which serves as a basic sanity check. +# +# Next steps +# ---------- +# +# - :ref:`torchsim-batched` explains running multiple systems at once + +# %% +# +# .. rst-class:: sphx-glr-script-out +# +# Cleanup the temporary model file: + +import shutil # noqa: E402 + + +shutil.rmtree(tmpdir) diff --git a/python/examples/6-torchsim-batched.py b/python/examples/6-torchsim-batched.py new file mode 100644 index 000000000..8f2d40614 --- /dev/null +++ b/python/examples/6-torchsim-batched.py @@ -0,0 +1,175 @@ +""" +.. _torchsim-batched: + +Batched simulations with TorchSim +================================= + +TorchSim supports batching multiple systems into a single ``SimState`` +for efficient parallel evaluation on GPU. +:py:class:`~metatomic_torchsim.MetatomicModel` handles this +transparently. +""" + +# %% +# +# Setup +# ----- +# +# We reuse the same minimal model from :ref:`torchsim-getting-started`. + +import os +import shutil +import tempfile +from typing import Dict, List, Optional + +import ase.build +import torch +import torch_sim as ts +from metatensor.torch import Labels, TensorBlock, TensorMap + +import metatomic.torch as mta +from metatomic_torchsim import MetatomicModel + + +class ConstantEnergy(torch.nn.Module): + """Assigns a constant energy per atom.""" + + def __init__(self, energy_per_atom: float = -1.0): + super().__init__() + self.energy_per_atom = energy_per_atom + + def forward( + self, + systems: List[mta.System], + outputs: Dict[str, mta.ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + energies = [] + for system in systems: + energies.append(self.energy_per_atom * len(system)) + + energy = torch.tensor(energies, dtype=systems[0].positions.dtype).reshape(-1, 1) + block = TensorBlock( + values=energy, + samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)), + components=[], + properties=Labels("energy", torch.tensor([[0]])), + ) + return { + "energy": TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block]) + } + + +tmpdir = tempfile.mkdtemp() +model_path = os.path.join(tmpdir, "constant-energy.pt") + +capabilities = mta.ModelCapabilities( + length_unit="Angstrom", + atomic_types=[13, 29], # Al, Cu + interaction_range=0.0, + outputs={"energy": mta.ModelOutput(quantity="energy", unit="eV")}, + supported_devices=["cpu"], + dtype="float64", +) + +atomistic_model = mta.AtomisticModel( + ConstantEnergy(-1.5).eval(), mta.ModelMetadata(), capabilities +) +atomistic_model.save(model_path) + +model = MetatomicModel(model_path, device="cpu") + + +# %% +# +# Creating a batched state +# ------------------------ +# +# Pass a list of ASE ``Atoms`` objects to ``initialize_state``: + +atoms_list = [ + ase.build.bulk("Cu", "fcc", a=3.6, cubic=True), + ase.build.bulk("Cu", "fcc", a=3.65, cubic=True), + ase.build.bulk("Al", "fcc", a=4.05, cubic=True), +] + +sim_state = ts.initialize_state(atoms_list, device=model.device, dtype=model.dtype) +print("Total atoms in batch:", sim_state.n_atoms) + +# %% +# +# Evaluating the batch +# -------------------- +# +# A single forward call evaluates all systems: + +results = model(sim_state) + +print("Energy shape:", results["energy"].shape) # [n_systems] +print("Forces shape:", results["forces"].shape) # [n_total_atoms, 3] +print("Stress shape:", results["stress"].shape) # [n_systems, 3, 3] + +# %% +# +# The output shapes reflect the batch: +# +# - ``results["energy"]`` has shape ``[n_systems]`` -- one energy per system +# - ``results["forces"]`` has shape ``[n_total_atoms, 3]`` -- all atoms +# concatenated +# - ``results["stress"]`` has shape ``[n_systems, 3, 3]`` -- one 3x3 tensor +# per system + +print("Per-system energies:", results["energy"]) + +# %% +# +# How ``system_idx`` works +# ------------------------ +# +# ``SimState`` tracks which atom belongs to which system via the +# ``system_idx`` tensor. For three 4-atom systems, ``system_idx`` looks +# like: + +print("system_idx:", sim_state.system_idx) + +# %% +# +# ``MetatomicModel.forward`` uses this to split the batched positions and +# types into per-system ``System`` objects before calling the underlying +# model. +# +# Batch consistency +# ----------------- +# +# Energies computed in a batch match those computed individually. +# This is guaranteed because each system gets its own neighbor list and +# independent evaluation: + +individual_energies = [] +for atoms in atoms_list: + state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) + res = model(state) + individual_energies.append(res["energy"].item()) + +print("Batched: ", [e.item() for e in results["energy"]]) +print("Individual:", individual_energies) + +# %% +# +# Performance considerations +# -------------------------- +# +# Batching is most beneficial on GPU, where the neighbor list computation +# and model forward pass can run in parallel across systems. On CPU, the +# speedup comes from reduced Python overhead (one call instead of N). +# +# For very large systems or many small ones, adjust the batch size to fit +# in GPU memory. TorchSim does not impose a maximum batch size, but each +# system gets its own neighbor list, so memory scales with the sum of +# per-system sizes. + +# %% +# +# Cleanup: + +shutil.rmtree(tmpdir) From 3664a09106c2802176ddec971c0c42055dec379a Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 25 Mar 2026 19:35:16 +0100 Subject: [PATCH 18/23] fix(torchsim): ...... --- docs/src/engines/torch-sim.rst | 5 +++- python/examples/5-torchsim-getting-started.py | 25 ++++--------------- .../metatomic_torchsim/_model.py | 4 +-- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index 1586d0173..d72c678ed 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -38,7 +38,10 @@ stresses are derived via autograd by default. The wrapper also supports: :py:attr:`~metatomic_torchsim.MetatomicModel.additional_outputs` attribute See the :py:class:`~metatomic_torchsim.MetatomicModel` API documentation below -for details on all parameters. +for details on all parameters, and the tutorials for worked examples: + +- :ref:`torchsim-getting-started` -- loading a model and running NVE dynamics +- :ref:`torchsim-batched` -- evaluating multiple systems in a single call How to use the code ^^^^^^^^^^^^^^^^^^^ diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py index 8086e5b1a..335ad3821 100644 --- a/python/examples/5-torchsim-getting-started.py +++ b/python/examples/5-torchsim-getting-started.py @@ -22,8 +22,6 @@ # # We start by importing the modules we need: -import os -import tempfile from typing import Dict, List, Optional import ase.build @@ -79,10 +77,7 @@ def forward( # %% # -# Export the model to a file so that ``MetatomicModel`` can load it: - -tmpdir = tempfile.mkdtemp() -model_path = os.path.join(tmpdir, "constant-energy.pt") +# Build an ``AtomisticModel`` wrapping the raw module: raw_model = ConstantEnergy(energy_per_atom=-1.5) capabilities = mta.ModelCapabilities( @@ -97,16 +92,17 @@ def forward( atomistic_model = mta.AtomisticModel( raw_model.eval(), mta.ModelMetadata(), capabilities ) -atomistic_model.save(model_path) # %% # # Load the model # -------------- # -# Wrap the exported model with :py:class:`~metatomic_torchsim.MetatomicModel`: +# Wrap the model with :py:class:`~metatomic_torchsim.MetatomicModel`. +# You can pass an ``AtomisticModel`` directly, or a path to a saved +# ``.pt`` file: -model = MetatomicModel(model_path, device="cpu") +model = MetatomicModel(atomistic_model, device="cpu") # %% # @@ -191,14 +187,3 @@ def forward( # ---------- # # - :ref:`torchsim-batched` explains running multiple systems at once - -# %% -# -# .. rst-class:: sphx-glr-script-out -# -# Cleanup the temporary model file: - -import shutil # noqa: E402 - - -shutil.rmtree(tmpdir) diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index fa434444a..428ff6b8f 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -106,9 +106,7 @@ def __init__( :param non_conservative: If ``True``, the model will be asked to compute non-conservative forces and stresses. This can afford a speed-up, potentially at the expense of physical correctness (especially in - molecular dynamics simulations). Non-conservative outputs are also - useful for multiple time stepping schemes (e.g. r-RESPA), where - different force components are evaluated at different frequencies. + molecular dynamics simulations). :param uncertainty_threshold: Threshold for per-atom energy uncertainty in eV. When the model supports ``energy_uncertainty`` with ``per_atom=True``, atoms exceeding this threshold trigger a warning. From fac08a04c0f8b2ea749f6aeee1e03fbc2e24219e Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Thu, 26 Mar 2026 14:05:20 +0100 Subject: [PATCH 19/23] fix(docs): make torchsim tutorials runnable in sphinx-gallery Use HarmonicEnergy (E = k*sum(pos^2)) instead of ConstantEnergy so autograd can compute forces/stress. Pass AtomisticModel directly to MetatomicModel to avoid TorchScript save/load. Fix NVE loop to re-evaluate model each step. --- python/examples/5-torchsim-getting-started.py | 30 ++++++++-------- python/examples/6-torchsim-batched.py | 34 +++++++------------ 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py index 335ad3821..7a390b796 100644 --- a/python/examples/5-torchsim-getting-started.py +++ b/python/examples/5-torchsim-getting-started.py @@ -39,17 +39,18 @@ # Export a simple model # --------------------- # -# For this tutorial we create and export a minimal model that predicts a -# constant per-atom energy. In practice you would use a pre-trained model -# loaded from a file. +# For this tutorial we create and export a minimal model that predicts +# energy as a (trivial) function of atomic positions. The energy must +# depend on positions so that forces can be computed via autograd. +# In practice you would use a pre-trained model loaded from a file. -class ConstantEnergy(torch.nn.Module): - """A minimal model that assigns a constant energy to each atom.""" +class HarmonicEnergy(torch.nn.Module): + """A minimal model: harmonic restraint around initial positions.""" - def __init__(self, energy_per_atom: float = -1.0): + def __init__(self, k: float = 0.1): super().__init__() - self.energy_per_atom = energy_per_atom + self.k = k def forward( self, @@ -57,12 +58,13 @@ def forward( outputs: Dict[str, mta.ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: - energies = [] + energies: List[torch.Tensor] = [] for system in systems: - n_atoms = len(system) - energies.append(self.energy_per_atom * n_atoms) + # energy = k * sum(positions^2) -- differentiable w.r.t. positions + e = self.k * torch.sum(system.positions**2) + energies.append(e.reshape(1, 1)) - energy = torch.tensor(energies, dtype=systems[0].positions.dtype).reshape(-1, 1) + energy = torch.cat(energies, dim=0) block = TensorBlock( values=energy, @@ -79,7 +81,7 @@ def forward( # # Build an ``AtomisticModel`` wrapping the raw module: -raw_model = ConstantEnergy(energy_per_atom=-1.5) +raw_model = HarmonicEnergy(k=0.1) capabilities = mta.ModelCapabilities( length_unit="Angstrom", atomic_types=[14], # Silicon @@ -164,8 +166,8 @@ def forward( for step in range(50): sim_state = integrator.step(sim_state, model) - energy = results["energy"].item() - energies.append(energy) + step_results = model(sim_state) + energies.append(step_results["energy"].item()) steps.append(step) plt.plot(steps, energies) diff --git a/python/examples/6-torchsim-batched.py b/python/examples/6-torchsim-batched.py index 8f2d40614..4d7e71dd1 100644 --- a/python/examples/6-torchsim-batched.py +++ b/python/examples/6-torchsim-batched.py @@ -16,10 +16,9 @@ # ----- # # We reuse the same minimal model from :ref:`torchsim-getting-started`. +# The model must produce differentiable energy so that forces/stress can +# be computed via autograd. -import os -import shutil -import tempfile from typing import Dict, List, Optional import ase.build @@ -31,12 +30,12 @@ from metatomic_torchsim import MetatomicModel -class ConstantEnergy(torch.nn.Module): - """Assigns a constant energy per atom.""" +class HarmonicEnergy(torch.nn.Module): + """Harmonic restraint: E = k * sum(positions^2).""" - def __init__(self, energy_per_atom: float = -1.0): + def __init__(self, k: float = 0.1): super().__init__() - self.energy_per_atom = energy_per_atom + self.k = k def forward( self, @@ -44,11 +43,12 @@ def forward( outputs: Dict[str, mta.ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: - energies = [] + energies: List[torch.Tensor] = [] for system in systems: - energies.append(self.energy_per_atom * len(system)) + e = self.k * torch.sum(system.positions**2) + energies.append(e.reshape(1, 1)) - energy = torch.tensor(energies, dtype=systems[0].positions.dtype).reshape(-1, 1) + energy = torch.cat(energies, dim=0) block = TensorBlock( values=energy, samples=Labels("system", torch.arange(len(systems)).reshape(-1, 1)), @@ -60,9 +60,6 @@ def forward( } -tmpdir = tempfile.mkdtemp() -model_path = os.path.join(tmpdir, "constant-energy.pt") - capabilities = mta.ModelCapabilities( length_unit="Angstrom", atomic_types=[13, 29], # Al, Cu @@ -73,11 +70,10 @@ def forward( ) atomistic_model = mta.AtomisticModel( - ConstantEnergy(-1.5).eval(), mta.ModelMetadata(), capabilities + HarmonicEnergy(0.1).eval(), mta.ModelMetadata(), capabilities ) -atomistic_model.save(model_path) -model = MetatomicModel(model_path, device="cpu") +model = MetatomicModel(atomistic_model, device="cpu") # %% @@ -167,9 +163,3 @@ def forward( # in GPU memory. TorchSim does not impose a maximum batch size, but each # system gets its own neighbor list, so memory scales with the sum of # per-system sizes. - -# %% -# -# Cleanup: - -shutil.rmtree(tmpdir) From c80d5d150a3e8108c2422a9f7ce43ce5116ecaf7 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Thu, 26 Mar 2026 15:14:48 +0100 Subject: [PATCH 20/23] fix(docs): remove velocities assignment, SimState uses momenta SimState has no velocities attribute; the integrator manages momenta internally. Remove the manual velocity initialization. --- python/examples/5-torchsim-getting-started.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py index 7a390b796..bf2352169 100644 --- a/python/examples/5-torchsim-getting-started.py +++ b/python/examples/5-torchsim-getting-started.py @@ -149,16 +149,14 @@ def forward( # Run NVE dynamics # ---------------- # -# Use TorchSim's Velocity Verlet integrator to run a short NVE trajectory: +# Use TorchSim's Velocity Verlet integrator to run a short NVE trajectory. +# The integrator manages momenta internally via ``SimState``: import matplotlib.pyplot as plt # noqa: E402 sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) -# Initialize with small random velocities -sim_state.velocities = 0.001 * torch.randn_like(sim_state.velocities) - energies = [] steps = [] From ef449e71cd2ffa4a5a019ebb5b2e0b7c8ebd9d7b Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Thu, 26 Mar 2026 15:15:28 +0100 Subject: [PATCH 21/23] fix(docs): use nve_init/nve_step functional API for NVE tutorial SimState has no velocities; use nve_init to create MDState with Maxwell-Boltzmann momenta, then nve_step for integration. --- python/examples/5-torchsim-getting-started.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py index bf2352169..9e7ec2fa6 100644 --- a/python/examples/5-torchsim-getting-started.py +++ b/python/examples/5-torchsim-getting-started.py @@ -149,29 +149,33 @@ def forward( # Run NVE dynamics # ---------------- # -# Use TorchSim's Velocity Verlet integrator to run a short NVE trajectory. -# The integrator manages momenta internally via ``SimState``: +# Use TorchSim's NVE (Velocity Verlet) integrator to run a short trajectory. +# ``nve_init`` samples momenta from a Maxwell-Boltzmann distribution at the +# given temperature, and ``nve_step`` advances by one timestep: import matplotlib.pyplot as plt # noqa: E402 - +from torch_sim.integrators import nve_init, nve_step # noqa: E402 +from torch_sim.units import MetalUnits # noqa: E402 sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) +# Initialize NVE state with momenta at 300 K (in eV units) +kT = 300.0 * MetalUnits.temperature # kelvin -> eV +md_state = nve_init(sim_state, model, kT=kT) + energies = [] steps = [] - -integrator = ts.integrators.VelocityVerletIntegrator(dt=1.0) +dt = 1.0 # femtoseconds for step in range(50): - sim_state = integrator.step(sim_state, model) - step_results = model(sim_state) - energies.append(step_results["energy"].item()) + md_state = nve_step(md_state, model, dt=dt) + energies.append(md_state.energy.sum().item()) steps.append(step) plt.plot(steps, energies) plt.xlabel("Step") -plt.ylabel("Energy (eV)") -plt.title("NVE dynamics -- energy vs step") +plt.ylabel("Potential energy (eV)") +plt.title("NVE dynamics -- potential energy vs step") plt.tight_layout() plt.show() From d757810584a707553b043e7b6ff759a9b4bafb4a Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Thu, 26 Mar 2026 15:30:52 +0100 Subject: [PATCH 22/23] fix(docs): sort imports to satisfy ruff I001 --- python/examples/5-torchsim-getting-started.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py index 9e7ec2fa6..f5ec0e5d4 100644 --- a/python/examples/5-torchsim-getting-started.py +++ b/python/examples/5-torchsim-getting-started.py @@ -157,6 +157,7 @@ def forward( from torch_sim.integrators import nve_init, nve_step # noqa: E402 from torch_sim.units import MetalUnits # noqa: E402 + sim_state = ts.initialize_state(atoms, device=model.device, dtype=model.dtype) # Initialize NVE state with momenta at 300 K (in eV units) From bbaa7e4371b9e5cfa1c3de0cea2ea01e72134a39 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Thu, 26 Mar 2026 16:20:29 +0100 Subject: [PATCH 23/23] Small doc tweaks --- docs/src/engines/torch-sim.rst | 11 ++--------- python/examples/5-torchsim-getting-started.py | 2 -- python/examples/6-torchsim-batched.py | 11 +++++++++++ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/src/engines/torch-sim.rst b/docs/src/engines/torch-sim.rst index d72c678ed..a37d2d667 100644 --- a/docs/src/engines/torch-sim.rst +++ b/docs/src/engines/torch-sim.rst @@ -1,7 +1,7 @@ .. _engine-torch-sim: -torch-sim -========= +TorchSim +======== .. list-table:: :header-rows: 1 @@ -68,10 +68,3 @@ API documentation .. autoclass:: metatomic_torchsim.MetatomicModel :show-inheritance: :members: - -.. toctree:: - :maxdepth: 2 - :caption: Tutorials - - ../examples/5-torchsim-getting-started - ../examples/6-torchsim-batched diff --git a/python/examples/5-torchsim-getting-started.py b/python/examples/5-torchsim-getting-started.py index f5ec0e5d4..f39a5f10e 100644 --- a/python/examples/5-torchsim-getting-started.py +++ b/python/examples/5-torchsim-getting-started.py @@ -32,8 +32,6 @@ from metatomic_torchsim import MetatomicModel -# sphinx_gallery_thumbnail_number = 2 - # %% # # Export a simple model diff --git a/python/examples/6-torchsim-batched.py b/python/examples/6-torchsim-batched.py index 4d7e71dd1..6a2e4730d 100644 --- a/python/examples/6-torchsim-batched.py +++ b/python/examples/6-torchsim-batched.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional import ase.build +import matplotlib.pyplot as plt import torch import torch_sim as ts from metatensor.torch import Labels, TensorBlock, TensorMap @@ -150,6 +151,16 @@ def forward( print("Batched: ", [e.item() for e in results["energy"]]) print("Individual:", individual_energies) +plt.scatter(individual_energies, results["energy"].cpu().numpy()) +plt.plot( + [min(individual_energies), max(individual_energies)], + [min(individual_energies), max(individual_energies)], + "k--", +) +plt.xlabel("Individual energies") +plt.ylabel("Batched energies") +plt.show() + # %% # # Performance considerations