diff --git a/pyproject.toml b/pyproject.toml index d6c102e03..90de3c4a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "h5py>=3.15.1", "numpy>=1.26,<3; python_version < '3.13'", "numpy>=2.3.2,<3; python_version >= '3.13'", - "nvalchemi-toolkit-ops>=0.2.0", + "nvalchemi-toolkit-ops>=0.3.0", "tables>=3.11.1", "torch>=2", "tqdm>=4.67", @@ -38,7 +38,7 @@ dependencies = [ [project.optional-dependencies] test = [ - "torch-sim-atomistic[io,symmetry]", + "torch-sim-atomistic[io,symmetry,vesin]", "platformdirs>=4.0.0", "psutil>=7.0.0", "pymatgen>=2025.6.14", @@ -47,6 +47,7 @@ test = [ "spglib>=2.6", "vesin[torch]>=0.5.3", ] +vesin = ["vesin[torch]>=0.5.3"] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] symmetry = ["moyopy>=0.7.8"] mace = ["mace-torch>=0.3.15"] diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index ea5076d7a..d1488e6bc 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -34,7 +34,7 @@ def make_structure(name: str, repeats: int = REPEATS) -> Atoms: "fcc": lambda: bulk("Cu", "fcc", a=3.6), "hcp": lambda: bulk("Ti", "hcp", a=2.95, c=4.68), "diamond": lambda: bulk("Si", "diamond", a=5.43), - "bcc": lambda: bulk("Al", "bcc", a=2 / np.sqrt(3), cubic=True), + "bcc": lambda: bulk("Al", "bcc", a=4 / np.sqrt(3), cubic=True), "p6bar": lambda: crystal( "Si", [(0.3, 0.1, 0.25)], @@ -128,7 +128,7 @@ def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tenso @pytest.fixture def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper: """LJ model with noise added to forces/stress.""" - return NoisyModelWrapper(model) + return NoisyModelWrapper(model, noise_scale=5e-1, concentration=1.0) @pytest.fixture diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index b96c26f24..c4e8ae0f8 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -1,9 +1,8 @@ -import time +import re from collections.abc import Callable -from typing import Any, cast +from typing import Any import numpy as np -import psutil import pytest import torch from ase import Atoms @@ -60,43 +59,62 @@ def ase_to_torch_batch( ) -# Adapted from torch_nl test -# https://github.com/felixmusil/torch_nl/blob/main/torch_nl/test_nl.py - -# triclinic atomic structure -CaCrP2O7_mvc_11955_symmetrized = { - "positions": [ - [3.68954016, 5.03568186, 4.64369552], - [5.12301681, 2.13482791, 2.66220405], - [1.99411973, 0.94691001, 1.25068234], - [6.81843724, 6.22359976, 6.05521724], - [2.63005662, 4.16863452, 0.86090529], - [6.18250036, 3.00187525, 6.44499428], - [2.11497733, 1.98032773, 4.53610884], - [6.69757964, 5.19018203, 2.76979073], - [1.39215545, 2.94386142, 5.60917746], - [7.42040152, 4.22664834, 1.69672212], - [2.43224207, 5.4571615, 6.70305327], - [6.3803149, 1.71334827, 0.6028463], - [1.11265639, 1.50166318, 3.48760997], - [7.69990058, 5.66884659, 3.8182896], - [3.56971588, 5.20836551, 1.43673437], - [5.2428411, 1.96214426, 5.8691652], - [3.12282634, 2.72812741, 1.05450432], - [5.68973063, 4.44238236, 6.25139525], - [3.24868468, 2.83997522, 3.99842386], - [5.56387229, 4.33053455, 3.30747571], - [2.60835346, 0.74421609, 5.3236629], - [6.20420351, 6.42629368, 1.98223667], - ], - "cell": [ - [6.19330899, 0.0, 0.0], - [2.4074486111396207, 6.149627748674982, 0.0], - [0.2117993724186579, 1.0208820183960539, 7.305899571570074], - ], - "numbers": [*[20] * 2, *[24] * 2, *[15] * 4, *[8] * 14], - "pbc": [True, True, True], -} +def _make_triclinic_atoms() -> Atoms: + """CaCrP2O7 (mvc-11955) triclinic structure. + + Adapted from https://github.com/felixmusil/torch_nl/blob/main/torch_nl/test_nl.py + """ + return Atoms( + positions=[ + [3.68954016, 5.03568186, 4.64369552], + [5.12301681, 2.13482791, 2.66220405], + [1.99411973, 0.94691001, 1.25068234], + [6.81843724, 6.22359976, 6.05521724], + [2.63005662, 4.16863452, 0.86090529], + [6.18250036, 3.00187525, 6.44499428], + [2.11497733, 1.98032773, 4.53610884], + [6.69757964, 5.19018203, 2.76979073], + [1.39215545, 2.94386142, 5.60917746], + [7.42040152, 4.22664834, 1.69672212], + [2.43224207, 5.4571615, 6.70305327], + [6.3803149, 1.71334827, 0.6028463], + [1.11265639, 1.50166318, 3.48760997], + [7.69990058, 5.66884659, 3.8182896], + [3.56971588, 5.20836551, 1.43673437], + [5.2428411, 1.96214426, 5.8691652], + [3.12282634, 2.72812741, 1.05450432], + [5.68973063, 4.44238236, 6.25139525], + [3.24868468, 2.83997522, 3.99842386], + [5.56387229, 4.33053455, 3.30747571], + [2.60835346, 0.74421609, 5.3236629], + [6.20420351, 6.42629368, 1.98223667], + ], + cell=[ + [6.19330899, 0.0, 0.0], + [2.4074486111396207, 6.149627748674982, 0.0], + [0.2117993724186579, 1.0208820183960539, 7.305899571570074], + ], + numbers=[*[20] * 2, *[24] * 2, *[15] * 4, *[8] * 14], + pbc=[True, True, True], + ) + + +def _make_very_skewed_atoms() -> Atoms: + """Bi rhombohedral α=10° — extremely skewed, triggers nvalchemiops overflow.""" + atoms = bulk("Bi", "rhombohedral", a=6, alpha=10) + atoms.info["very_skewed"] = True + return atoms + + +@pytest.fixture +def periodic_atoms_unwrap_subset() -> list[Atoms]: + """Fully periodic crystals used to test invariance under lattice translations.""" + return [ + bulk("Si", "diamond", a=6, cubic=True), + bulk("Cu", "fcc", a=3.6), + bulk("Si", "diamond", a=6), + _make_triclinic_atoms(), + ] @pytest.fixture @@ -107,14 +125,11 @@ def periodic_atoms_set(): bulk("Cu", "fcc", a=3.6), bulk("Si", "bct", a=6, c=3), bulk("Ti", "hcp", a=2.94, c=4.64, orthorhombic=False), - # test very skewed rhombohedral cells bulk("Bi", "rhombohedral", a=6, alpha=20), - bulk( - "Bi", "rhombohedral", a=6, alpha=10 - ), # very skewed, by far the slowest test case bulk("SiCu", "rocksalt", a=6), bulk("SiFCu", "fluorite", a=6), - Atoms(**CaCrP2O7_mvc_11955_symmetrized), + _make_triclinic_atoms(), + _make_very_skewed_atoms(), ] @@ -125,23 +140,158 @@ def molecule_atoms_set() -> list: ] -@pytest.mark.parametrize("cutoff", [1, 3, 5, 7]) +def _sorted_mic_distances( + positions: torch.Tensor, + row_vector_cell: torch.Tensor, + mapping: torch.Tensor, + mapping_system: torch.Tensor, + shifts_idx: torch.Tensor, +) -> np.ndarray: + cell_shifts = transforms.compute_cell_shifts( + row_vector_cell, shifts_idx, mapping_system + ) + d = transforms.compute_distances_with_cell_shifts(positions, mapping, cell_shifts) + return np.sort(d.detach().cpu().numpy()) + + +def _integer_lattice_shift_positions( + positions: torch.Tensor, + cell_batched: torch.Tensor, + system_idx: torch.Tensor, + integers: torch.Tensor, +) -> torch.Tensor: + cell_per_atom = cell_batched[system_idx.to(cell_batched.device)] + delta = (integers.to(cell_per_atom.dtype).unsqueeze(-1) * cell_per_atom).sum(dim=1) + return positions + delta + + +def _all_nl_backends() -> list[Any]: + """All NL backends as pytest.params with skipif marks for optional deps.""" + _skip_vesin = pytest.mark.skipif( + not neighbors.VESIN_AVAILABLE, reason="Vesin is not installed" + ) + _skip_vesin_ts = pytest.mark.skipif( + not neighbors.VESIN_TORCH_AVAILABLE, reason="Vesin is not installed" + ) + + _skip_alchemiops = pytest.mark.skipif( + not neighbors.ALCHEMIOPS_AVAILABLE, reason="nvalchemiops is not installed" + ) + return [ + pytest.param(neighbors.torch_nl_n2, id="torch_nl_n2"), + pytest.param(neighbors.torch_nl_linked_cell, id="torch_nl_linked_cell"), + pytest.param(neighbors.vesin_nl, id="vesin_nl", marks=_skip_vesin), + pytest.param(neighbors.vesin_nl_ts, id="vesin_nl_ts", marks=_skip_vesin_ts), + pytest.param( + neighbors.alchemiops_nl_n2, + id="alchemiops_nl_n2", + marks=_skip_alchemiops, + ), + pytest.param( + neighbors.alchemiops_nl_cell_list, + id="alchemiops_nl_cell_list", + marks=_skip_alchemiops, + ), + ] + + +def _nl_backends_x_cutoffs(cutoffs: list[float] | None = None) -> list[Any]: + """Cross-product of all NL backends x cutoffs, preserving skip marks.""" + if cutoffs is None: + cutoffs = [1, 3, 5, 7] + return [ + pytest.param(p.values[0], c, id=f"{p.values[0].__name__}-{c}", marks=p.marks) + for p in _all_nl_backends() + for c in cutoffs + ] + + +@pytest.mark.parametrize("cutoff", [2.0, 5.0, 7.0]) @pytest.mark.parametrize("self_interaction", [True, False]) +@pytest.mark.parametrize("shift_mode", ["uniform", "per_atom"]) +@pytest.mark.parametrize("nl_implementation", _all_nl_backends()) +def test_neighbor_list_invariant_under_lattice_image_shifts( + *, + cutoff: float, + self_interaction: bool, + shift_mode: str, + nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + periodic_atoms_unwrap_subset: list[Atoms], +) -> None: + """NL backends: same sorted MIC distances and pair count after lattice-image shifts. + + ``uniform``: one integer triplet per system applied to all its atoms (rigid image). + ``per_atom``: independent integer triplet per atom (same structure mod PBC). + + Backends: ``torch_nl_n2``, ``torch_nl_linked_cell``, ``vesin_nl``, ``vesin_nl_ts``, + ``alchemiops_nl_n2``, ``alchemiops_nl_cell_list`` (latter four skip if optional + deps missing). See TorchSim/torch-sim#423, #437. + """ + atoms_list = periodic_atoms_unwrap_subset + pos_wrapped, cell_flat, pbc_flat, batch, _ = ase_to_torch_batch( + atoms_list, device=DEVICE, dtype=DTYPE + ) + n_sys = len(atoms_list) + cell_b = cell_flat.view(n_sys, 3, 3) + pbc_b = pbc_flat.view(n_sys, 3) + pbc_on_atom = pbc_b[batch] + if shift_mode == "uniform": + triplets = torch.tensor( + [[2, -1, 1], [-3, 0, 2], [1, 1, -2], [2, 2, -3]], + dtype=torch.long, + device=DEVICE, + ) + per_system = triplets[torch.arange(n_sys, device=DEVICE) % triplets.shape[0]] + ints = per_system[batch] * pbc_on_atom.long() + elif shift_mode == "per_atom": + n_atoms = pos_wrapped.shape[0] + ar = torch.arange(n_atoms, device=DEVICE, dtype=torch.long) + ints = torch.stack( + [(ar % 3) - 1, (ar % 5) - 2, (ar % 4) - 2], + dim=1, + ) + ints = ints * pbc_on_atom.long() + else: + raise AssertionError(f"unknown shift_mode: {shift_mode}") + pos_shifted = _integer_lattice_shift_positions(pos_wrapped, cell_b, batch, ints) + assert not torch.allclose(pos_shifted, pos_wrapped, rtol=0.0, atol=1e-12), ( + "expected non-trivial lattice shifts along periodic axes" + ) + c_tensor = torch.tensor(cutoff, dtype=DTYPE, device=DEVICE) + map_w, sys_w, sh_w = nl_implementation( + cutoff=c_tensor, + positions=pos_wrapped, + cell=cell_b, + pbc=pbc_b, + system_idx=batch, + self_interaction=self_interaction, + ) + map_s, sys_s, sh_s = nl_implementation( + cutoff=c_tensor, + positions=pos_shifted, + cell=cell_b, + pbc=pbc_b, + system_idx=batch, + self_interaction=self_interaction, + ) + d_w = _sorted_mic_distances(pos_wrapped, cell_b, map_w, sys_w, sh_w) + d_s = _sorted_mic_distances(pos_shifted, cell_b, map_s, sys_s, sh_s) + np.testing.assert_allclose(d_w, d_s, rtol=1e-5, atol=1e-5) + assert map_w.shape[1] == map_s.shape[1] + assert torch.equal(batch[map_w[0]], batch[map_w[1]]) + assert torch.equal(batch[map_s[0]], batch[map_s[1]]) + + @pytest.mark.parametrize( - "nl_implementation", - [neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell] - + ([neighbors.vesin_nl, neighbors.vesin_nl_ts] if neighbors.VESIN_AVAILABLE else []) - + ( - [neighbors.alchemiops_nl_n2, neighbors.alchemiops_nl_cell_list] - if neighbors.ALCHEMIOPS_AVAILABLE - else [] - ), + ("nl_implementation", "cutoff"), + _nl_backends_x_cutoffs(), ) +@pytest.mark.parametrize("self_interaction", [True, False]) def test_neighbor_list_implementations( *, + nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]], cutoff: float, self_interaction: bool, - nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]], molecule_atoms_set: list[Atoms], periodic_atoms_set: list[Atoms], ) -> None: @@ -151,13 +301,15 @@ def test_neighbor_list_implementations( systems, comparing sorted distances against ASE reference values. """ atoms_list = molecule_atoms_set + periodic_atoms_set + is_alchemiops = "alchemiops" in nl_implementation.__name__ + if is_alchemiops and cutoff >= 3: + atoms_list = [a for a in atoms_list if not a.info.get("very_skewed")] # NOTE we can't use atoms_to_state here because we want to test mixed # periodic and non-periodic systems pos, row_vector_cell, pbc, batch, _ = ase_to_torch_batch( atoms_list, device=DEVICE, dtype=DTYPE ) - mapping, mapping_system, shifts_idx = nl_implementation( cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), positions=pos, @@ -216,16 +368,7 @@ def test_neighbor_list_implementations( @pytest.mark.parametrize("self_interaction", [True, False]) @pytest.mark.parametrize("pbc_val", [True, False]) -@pytest.mark.parametrize( - "nl_implementation", - [neighbors.torch_nl_n2, neighbors.torch_nl_linked_cell] - + ([neighbors.vesin_nl, neighbors.vesin_nl_ts] if neighbors.VESIN_AVAILABLE else []) - + ( - [neighbors.alchemiops_nl_n2, neighbors.alchemiops_nl_cell_list] - if neighbors.ALCHEMIOPS_AVAILABLE and torch.cuda.is_available() - else [] - ), -) +@pytest.mark.parametrize("nl_implementation", _all_nl_backends()) def test_nl_pbc_edge_cases( *, pbc_val: bool, self_interaction: bool, nl_implementation: Callable[..., Any] ) -> None: @@ -279,34 +422,48 @@ def _minimal_neighbor_list_inputs( return positions, cell, pbc, cutoff, system_idx -def test_vesin_nl_availability() -> None: - """Test that availability flags are correctly set.""" +def test_optional_neighbor_backends_expose_flags_and_entrypoints() -> None: + """Public API: booleans and callables always present after import.""" assert isinstance(neighbors.VESIN_AVAILABLE, bool) - - assert callable(neighbors.vesin_nl) - assert callable(neighbors.vesin_nl_ts) - - if not neighbors.VESIN_AVAILABLE: - positions, cell, pbc, cutoff, system_idx = _minimal_neighbor_list_inputs(DEVICE) - with pytest.raises(ImportError, match="Vesin is not installed"): - neighbors.vesin_nl(positions, cell, pbc, cutoff, system_idx) - with pytest.raises(ImportError, match="Vesin is not installed"): - neighbors.vesin_nl_ts(positions, cell, pbc, cutoff, system_idx) + assert isinstance(neighbors.ALCHEMIOPS_AVAILABLE, bool) + for name in ( + "vesin_nl", + "vesin_nl_ts", + "alchemiops_nl_n2", + "alchemiops_nl_cell_list", + ): + assert callable(getattr(neighbors, name)) -def test_alchemiops_nl_availability() -> None: - """Test that alchemiops optional dependency flags and errors are consistent.""" - assert isinstance(neighbors.ALCHEMIOPS_AVAILABLE, bool) +@pytest.mark.parametrize( + ("fn_names", "message"), + [ + ( + ("vesin_nl", "vesin_nl_ts"), + "Vesin is not installed. Install it with: pip install vesin", + ), + ( + ("alchemiops_nl_n2", "alchemiops_nl_cell_list"), + "nvalchemiops is not installed. Install it with: pip install nvalchemiops", + ), + ], +) +def test_neighbor_list_stub_import_errors_match_documentation( + monkeypatch: pytest.MonkeyPatch, + fn_names: tuple[str, ...], + message: str, +) -> None: + """Stubs must raise the same ImportError as optional-backend fallbacks.""" - assert callable(neighbors.alchemiops_nl_n2) - assert callable(neighbors.alchemiops_nl_cell_list) + def _stub(*args: object, **kwargs: object) -> None: # noqa: ARG001 + raise ImportError(message) - if not neighbors.ALCHEMIOPS_AVAILABLE: - positions, cell, pbc, cutoff, system_idx = _minimal_neighbor_list_inputs(DEVICE) - with pytest.raises(ImportError, match="nvalchemiops is not installed"): - neighbors.alchemiops_nl_n2(positions, cell, pbc, cutoff, system_idx) - with pytest.raises(ImportError, match="nvalchemiops is not installed"): - neighbors.alchemiops_nl_cell_list(positions, cell, pbc, cutoff, system_idx) + for fn_name in fn_names: + monkeypatch.setattr(neighbors, fn_name, _stub) + args = _minimal_neighbor_list_inputs(DEVICE) + for fn_name in fn_names: + with pytest.raises(ImportError, match=re.escape(message)): + getattr(neighbors, fn_name)(*args) def test_fallback_when_alchemiops_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: @@ -474,77 +631,3 @@ def test_strict_nl_edge_cases() -> None: shifts_idx=shifts_idx, ) assert len(new_mapping[0]) > 0 # Should find neighbors - - -def test_neighbor_lists_time_and_memory() -> None: - """Test performance and memory characteristics of neighbor list implementations.""" - # Create a smaller system to reduce memory usage - n_atoms = 100 - pos = torch.rand(n_atoms, 3, device=DEVICE, dtype=DTYPE) - cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 10.0 - cutoff = torch.tensor(2.0, device=DEVICE, dtype=DTYPE) - - # Test different implementations - nl_implementations = [ - neighbors.torch_nl_n2, - neighbors.torch_nl_linked_cell, - ] - if neighbors.VESIN_AVAILABLE: - nl_implementations.extend( - [ - neighbors.vesin_nl_ts, - cast("Callable[..., Any]", neighbors.vesin_nl), - ] - ) - if neighbors.ALCHEMIOPS_AVAILABLE and DEVICE.type == "cuda": - nl_implementations.extend( - [neighbors.alchemiops_nl_n2, neighbors.alchemiops_nl_cell_list] - ) - - for nl_fn in nl_implementations: - # Get initial memory usage - process = psutil.Process() - initial_cpu_memory = process.memory_info().rss # in bytes - - if DEVICE.type == "cuda": - torch.cuda.reset_peak_memory_stats() - initial_gpu_memory = torch.cuda.memory_allocated() - - # Time the execution - start_time = time.perf_counter() - - # All neighbor list functions now use the unified API with system_idx - system_idx = torch.zeros(n_atoms, dtype=torch.long, device=DEVICE) - # Fix pbc tensor shape - pbc = torch.tensor([[True, True, True]], device=DEVICE) - _mapping, _mapping_system, _shifts_idx = nl_fn( - positions=pos, - cell=cell, - pbc=pbc, - cutoff=cutoff, - system_idx=system_idx, - self_interaction=False, - ) - - end_time = time.perf_counter() - execution_time = end_time - start_time - - # Get final memory usage - final_cpu_memory = process.memory_info().rss - cpu_memory_used = final_cpu_memory - initial_cpu_memory - fn_name = str(nl_fn) - - # Warning: cuda case was never tested, to be tweaked later - if DEVICE.type == "cuda": - final_gpu_memory = torch.cuda.memory_allocated() - gpu_memory_used = final_gpu_memory - initial_gpu_memory - assert execution_time < 0.01, f"{fn_name} took too long: {execution_time}s" - assert gpu_memory_used < 5e8, ( - f"{fn_name} used too much GPU memory: {gpu_memory_used / 1e6:.2f}MB" - ) - torch.cuda.empty_cache() - else: - assert cpu_memory_used < 5e8, ( - f"{fn_name} used too much CPU memory: {cpu_memory_used / 1e6:.2f}MB" - ) - assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s" diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e3eef6397..bac5ec985 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -355,6 +355,66 @@ def test_pbc_wrap_batched_preserves_relative_positions( assert torch.allclose(orig_vec, wrapped_vec, atol=1e-6) +def test_pbc_wrap_batched_and_get_lattice_shifts() -> None: + """Test that wrapping returns correct positions and integer lattice shifts.""" + cell1 = torch.tensor( + [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]], + dtype=torch.float64, + device=DEVICE, + ) + cell2 = torch.tensor( + [[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]], + dtype=torch.float64, + device=DEVICE, + ) + cell = torch.stack([cell1, cell2]) + pbc = torch.tensor([[True, True, True], [True, True, True]], device=DEVICE) + positions = torch.tensor( + [[4.0, -1.0, 7.5], [3.0, 2.5, 4.5]], + dtype=torch.float64, + device=DEVICE, + ) + system_idx = torch.tensor([0, 1], device=DEVICE) + wrapped, shifts = tst.pbc_wrap_batched_and_get_lattice_shifts( + positions, cell, system_idx, pbc + ) + assert wrapped.shape == positions.shape + assert shifts.shape == positions.shape + reconstructed = wrapped + (shifts.unsqueeze(-1) * cell[system_idx]).sum(dim=1) + torch.testing.assert_close(reconstructed, positions, atol=1e-10, rtol=0.0) + assert (shifts[0] != 0).any(), "expected non-zero shifts for displaced atom" + + +def test_pbc_wrap_batched_and_get_lattice_shifts_singular_cell() -> None: + """Singular cells and non-periodic systems are left unchanged.""" + cell = torch.zeros(1, 3, 3, dtype=torch.float64, device=DEVICE) + pbc = torch.tensor([[True, True, True]], device=DEVICE) + positions = torch.tensor([[5.0, 5.0, 5.0]], dtype=torch.float64, device=DEVICE) + system_idx = torch.tensor([0], device=DEVICE) + wrapped, shifts = tst.pbc_wrap_batched_and_get_lattice_shifts( + positions, cell, system_idx, pbc + ) + torch.testing.assert_close(wrapped, positions) + assert (shifts == 0).all() + + +def test_pbc_wrap_batched_and_get_lattice_shifts_non_periodic() -> None: + """Non-periodic axes should not be wrapped.""" + cell = torch.eye(3, dtype=torch.float64, device=DEVICE).unsqueeze(0) * 2.0 + pbc = torch.tensor([[True, False, True]], device=DEVICE) + positions = torch.tensor([[3.0, 5.0, 5.0]], dtype=torch.float64, device=DEVICE) + system_idx = torch.tensor([0], device=DEVICE) + wrapped, shifts = tst.pbc_wrap_batched_and_get_lattice_shifts( + positions, cell, system_idx, pbc + ) + assert wrapped[0, 0] == 1.0, "periodic x should wrap 3.0 -> 1.0 in cell=2" + assert wrapped[0, 1] == 5.0, "non-periodic y should stay at 5.0" + assert wrapped[0, 2] == 1.0, "periodic z should wrap 5.0 -> 1.0 in cell=2" + assert shifts[0, 0] == 1, "x shift should be 1 (floor(3.0/2.0))" + assert shifts[0, 1] == 0, "non-periodic y shift should be 0" + assert shifts[0, 2] == 2, "z shift should be 2 (floor(5.0/2.0))" + + def test_safe_mask_basic() -> None: """Test basic functionality of safe_mask with log function. diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 0d0741537..8aa6bb5e6 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -35,6 +35,9 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): from torch_sim.typing import MemoryScaling +VALIDATE_ATOL = 1e-4 + + class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in TorchSim. @@ -307,21 +310,21 @@ def validate_model_outputs( # noqa: C901, PLR0915 if stress_computed and model_output["stress"].shape != (3, 3, 3): raise ValueError(f"{model_output['stress'].shape=} != (3, 3, 3)") + # Test single Si system output shapes (8 atoms) si_state = ts.io.atoms_to_state([si_atoms], device, dtype) si_model_output = model.forward(si_state) if not torch.allclose( - si_model_output["energy"], model_output["energy"][0], atol=1e-3 + si_model_output["energy"], model_output["energy"][0], atol=VALIDATE_ATOL ): raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}") if not torch.allclose( forces := si_model_output["forces"], expected_forces := model_output["forces"][: si_state.n_atoms], - atol=1e-3, + atol=VALIDATE_ATOL, ): raise ValueError(f"{forces=} != {expected_forces=}") - # Test single Si system output shapes (8 atoms) if si_model_output["energy"].shape != (1,): raise ValueError(f"{si_model_output['energy'].shape=} != (1,)") if force_computed and si_model_output["forces"].shape != (8, 3): @@ -329,10 +332,11 @@ def validate_model_outputs( # noqa: C901, PLR0915 if stress_computed and si_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)") + # Test single Mg system output shapes (12 atoms) mg_state = ts.io.atoms_to_state([mg_atoms], device, dtype) mg_model_output = model.forward(mg_state) if not torch.allclose( - mg_model_output["energy"], model_output["energy"][1], atol=1e-3 + mg_model_output["energy"], model_output["energy"][1], atol=VALIDATE_ATOL ): raise ValueError(f"{mg_model_output['energy']=} != {model_output['energy'][1]=}") mg_n = mg_state.n_atoms @@ -340,11 +344,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 if not torch.allclose( forces := mg_model_output["forces"], expected_forces := model_output["forces"][mg_slice], - atol=1e-3, + atol=VALIDATE_ATOL, ): raise ValueError(f"{forces=} != {expected_forces=}") - # Test single Mg system output shapes (12 atoms) if mg_model_output["energy"].shape != (1,): raise ValueError(f"{mg_model_output['energy'].shape=} != (1,)") if force_computed and mg_model_output["forces"].shape != (12, 3): @@ -352,16 +355,18 @@ def validate_model_outputs( # noqa: C901, PLR0915 if stress_computed and mg_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{mg_model_output['stress'].shape=} != (1, 3, 3)") + # Test single Fe system output shapes (1 atom) + # This catches that models do not squeeze away singleton dimensions. fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) fe_model_output = model.forward(fe_state) if not torch.allclose( - fe_model_output["energy"], model_output["energy"][2], atol=1e-3 + fe_model_output["energy"], model_output["energy"][2], atol=VALIDATE_ATOL ): raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][2]=}") if not torch.allclose( forces := fe_model_output["forces"], expected_forces := model_output["forces"][si_state.n_atoms + mg_n :], - atol=1e-3, + atol=VALIDATE_ATOL, ): raise ValueError(f"{forces=} != {expected_forces=}") @@ -371,3 +376,34 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{fe_model_output['forces'].shape=} != (1, 3)") if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)") + + # Translating one atom by a full lattice vector should not change outputs. + # This catches models that fail to apply periodic boundary conditions. + shifted_state = si_state.clone() + lattice_vec = shifted_state.cell[0, :, 0] # column convention + shifted_state.positions[0] = shifted_state.positions[0] + 3 * lattice_vec + shifted_output = model.forward(shifted_state) + if not torch.allclose( + shifted_output["energy"], si_model_output["energy"], atol=VALIDATE_ATOL + ): + raise ValueError( + "Energy changed after translating an atom by a lattice " + f"vector: {shifted_output['energy']=} != " + f"{si_model_output['energy']=}" + ) + if force_computed and not torch.allclose( + shifted_output["forces"], si_model_output["forces"], atol=VALIDATE_ATOL + ): + raise ValueError( + "Forces changed after translating an atom by a lattice " + "vector: max diff = " + f"{(shifted_output['forces'] - si_model_output['forces']).abs().max()}" + ) + if stress_computed and not torch.allclose( + shifted_output["stress"], si_model_output["stress"], atol=VALIDATE_ATOL + ): + raise ValueError( + "Stress changed after translating an atom by a lattice " + "vector: max diff = " + f"{(shifted_output['stress'] - si_model_output['stress']).abs().max()}" + ) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 7dfbb7657..e90392c69 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -279,39 +279,24 @@ def forward( # noqa: C901 self._setup_ptr(state.system_idx) self.system_idx = state.system_idx - # Wrap positions into the unit cell - wrapped_positions = ( - ts.transforms.pbc_wrap_batched( - state.positions, - state.cell, - state.system_idx, - state.pbc, - ) - if state.pbc.any() - else state.positions - ) - - # Batched neighbor list using linked-cell algorithm edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( - wrapped_positions, + state.positions, state.row_vector_cell, state.pbc, self.r_max, state.system_idx, ) - # Convert unit cell shift indices to Cartesian shifts shifts = ts.transforms.compute_cell_shifts( state.row_vector_cell, unit_shifts, mapping_system ) - # Build data dict for MACE model data_dict = dict( ptr=self.ptr, node_attrs=self.node_attrs, batch=state.system_idx, pbc=state.pbc, cell=state.row_vector_cell, - positions=wrapped_positions, + positions=state.positions, edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, diff --git a/torch_sim/models/pair_potential.py b/torch_sim/models/pair_potential.py index 1db1c2c9b..bfc4d474d 100644 --- a/torch_sim/models/pair_potential.py +++ b/torch_sim/models/pair_potential.py @@ -74,7 +74,7 @@ def bmhtf_pair(dr, zi, zj, A, B, C, D, sigma): from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.transforms import compute_cell_shifts, pbc_wrap_batched +from torch_sim.transforms import compute_cell_shifts if TYPE_CHECKING: @@ -161,18 +161,12 @@ def _prepare_pairs( else torch.zeros(positions.shape[0], dtype=torch.long, device=device) ) - wrapped_positions = ( - pbc_wrap_batched(positions, sim_state.cell, system_idx, pbc) - if pbc.any() - else positions - ) - pbc_batched = ( pbc.unsqueeze(0).expand(sim_state.n_systems, -1) if pbc.ndim == 1 else pbc ) mapping, system_mapping, shifts_idx = neighbor_list_fn( - positions=wrapped_positions, + positions=positions, cell=row_cell, pbc=pbc_batched, cutoff=cutoff, @@ -185,7 +179,7 @@ def _prepare_pairs( ) cell_shifts = compute_cell_shifts(row_cell, shifts_idx, system_mapping) - dr_vec = wrapped_positions[mapping[1]] - wrapped_positions[mapping[0]] + cell_shifts + dr_vec = positions[mapping[1]] - positions[mapping[0]] + cell_shifts distances = dr_vec.norm(dim=1) return ( diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index d8a5614bb..3c123db20 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -5,9 +5,9 @@ and batched (multi-system) calculations. Available Implementations: - - Primitive: Pure PyTorch implementation (always available) - - Vesin: High-performance neighbor lists (optional, requires vesin package) - - Batched: Optimized for multiple systems (torch_nl_n2, torch_nl_linked_cell) + - Alchemiops: Warp-accelerated neighbor list implementation + - Vesin: High-performance neighbor list implementation + - TorchNL: Pure PyTorch implementation Default Neighbor Lists: The module automatically selects the best available implementation: @@ -16,9 +16,18 @@ import torch -from torch_sim.neighbors.torch_nl import strict_nl as strict_nl -from torch_sim.neighbors.torch_nl import torch_nl_linked_cell -from torch_sim.neighbors.torch_nl import torch_nl_n2 as torch_nl_n2 +from torch_sim.neighbors.alchemiops import ( + ALCHEMIOPS_AVAILABLE, + alchemiops_nl_cell_list, + alchemiops_nl_n2, +) +from torch_sim.neighbors.torch_nl import strict_nl, torch_nl_linked_cell, torch_nl_n2 +from torch_sim.neighbors.vesin import ( + VESIN_AVAILABLE, + VESIN_TORCH_AVAILABLE, + vesin_nl, + vesin_nl_ts, +) def _normalize_inputs( @@ -55,43 +64,15 @@ def _normalize_inputs( return cell, pbc -# Try to import Alchemiops implementations (NVIDIA CUDA acceleration) -try: - from torch_sim.neighbors.alchemiops import ( - ALCHEMIOPS_AVAILABLE, - alchemiops_nl_cell_list, - alchemiops_nl_n2, - ) -except ImportError: - ALCHEMIOPS_AVAILABLE = False - alchemiops_nl_n2 = None # type: ignore[assignment] - alchemiops_nl_cell_list = None # type: ignore[assignment] - -# Try to import Vesin implementations -try: - from torch_sim.neighbors.vesin import ( - VESIN_AVAILABLE, - VesinNeighborList, - VesinNeighborListTorch, - vesin_nl, - vesin_nl_ts, - ) -except ImportError: - VESIN_AVAILABLE = False - VesinNeighborList = None - VesinNeighborListTorch = None - vesin_nl = None # type: ignore[assignment] - vesin_nl_ts = None # type: ignore[assignment] - # Set default neighbor list based on what's available (priority order) if ALCHEMIOPS_AVAILABLE: # Alchemiops is fastest on NVIDIA GPUs default_batched_nl = alchemiops_nl_n2 +elif VESIN_TORCH_AVAILABLE: + default_batched_nl = vesin_nl_ts elif VESIN_AVAILABLE: - # Vesin is good fallback - default_batched_nl = vesin_nl_ts # Still use native for batched + default_batched_nl = vesin_nl else: - # Pure PyTorch fallback default_batched_nl = torch_nl_linked_cell @@ -138,8 +119,28 @@ def torchsim_nl( return alchemiops_nl_n2( positions, cell, pbc, cutoff, system_idx, self_interaction ) - if VESIN_AVAILABLE: + + if VESIN_TORCH_AVAILABLE: return vesin_nl_ts(positions, cell, pbc, cutoff, system_idx, self_interaction) + + if VESIN_AVAILABLE: + return vesin_nl(positions, cell, pbc, cutoff, system_idx, self_interaction) + return torch_nl_linked_cell( positions, cell, pbc, cutoff, system_idx, self_interaction ) + + +__all__ = [ + "ALCHEMIOPS_AVAILABLE", + "ALCHEMIOPS_TORCH_AVAILABLE", + "VESIN_AVAILABLE", + "VESIN_TORCH_AVAILABLE", + "alchemiops_nl_cell_list", + "alchemiops_nl_n2", + "strict_nl", + "torch_nl_linked_cell", + "torch_nl_n2", + "vesin_nl", + "vesin_nl_ts", +] diff --git a/torch_sim/neighbors/alchemiops.py b/torch_sim/neighbors/alchemiops.py index 28c7a614b..bb3dcc451 100644 --- a/torch_sim/neighbors/alchemiops.py +++ b/torch_sim/neighbors/alchemiops.py @@ -1,7 +1,9 @@ """Alchemiops-based neighbor list implementations. -This module provides high-performance CUDA-accelerated neighbor list calculations -using the nvalchemiops library. Supports both naive N^2 and cell list algorithms. +This module provides neighbor lists via nvalchemiops: prefer the PyTorch subtree +(``nvalchemiops.torch.neighbors``), typical for CUDA builds, and fall back to +``nvalchemiops.neighborlist`` when that import path is missing (CPU-oriented API +with the same call surface). Supports naive N^2 and cell-list algorithms. nvalchemiops is available at: https://github.com/NVIDIA/nvalchemiops """ @@ -13,18 +15,33 @@ _batch_cell_list: object | None = None -try: - from nvalchemiops.neighborlist import batch_cell_list as _batch_cell_list - from nvalchemiops.neighborlist import ( - batch_naive_neighbor_list as _batch_naive_neighbor_list, - ) +def _import_nvalchemiops_batch_neighbors() -> tuple[object, object] | None: + """Return ``(batch_cell_list, batch_naive_neighbor_list)`` if a layout is importable. + + Tries ``nvalchemiops.torch.neighbors`` first (PyTorch tensors; usual GPU wheel). + On ``ImportError``, tries ``nvalchemiops.neighborlist`` — same API, CPU fallback + when the ``torch.neighbors`` subtree is absent. + """ + try: + from nvalchemiops.torch.neighbors.batch_cell_list import batch_cell_list as bcl + from nvalchemiops.torch.neighbors.batch_naive import ( + batch_naive_neighbor_list as bnl, + ) + except (ImportError, RuntimeError): + try: + from nvalchemiops.neighborlist import batch_cell_list as bcl + from nvalchemiops.neighborlist import batch_naive_neighbor_list as bnl + except (ImportError, RuntimeError): + return None + return bcl, bnl - ALCHEMIOPS_AVAILABLE = True -except ImportError: - ALCHEMIOPS_AVAILABLE = False +_bound_batch_neighbors = _import_nvalchemiops_batch_neighbors() +ALCHEMIOPS_AVAILABLE = _bound_batch_neighbors is not None if ALCHEMIOPS_AVAILABLE: + assert _bound_batch_neighbors is not None # noqa: S101 + _batch_cell_list, _batch_naive_neighbor_list = _bound_batch_neighbors def alchemiops_nl_n2( positions: torch.Tensor, diff --git a/torch_sim/neighbors/torch_nl.py b/torch_sim/neighbors/torch_nl.py index ca1095cf2..687861e23 100644 --- a/torch_sim/neighbors/torch_nl.py +++ b/torch_sim/neighbors/torch_nl.py @@ -127,14 +127,15 @@ def torch_nl_n2( """Compute the neighbor list for a set of atomic structures using a naive neighbor search before applying a strict `cutoff`. - The atomic positions `pos` should be wrapped inside their respective unit cells. - This implementation uses a naive O(N²) neighbor search which can be slow for large systems but is simple and works reliably for small to medium systems. + Positions are wrapped into the primary cell internally for the search; the + returned ``shifts_idx`` are corrected so they remain valid for the **original** + (unwrapped) input positions. The input tensor is never modified. + Args: - positions (torch.Tensor [n_atom, 3]): A tensor containing the positions - of atoms wrapped inside their respective unit cells. + positions (torch.Tensor [n_atom, 3]): Cartesian positions (may be unwrapped). cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors. pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. @@ -149,43 +150,31 @@ def torch_nl_n2( Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mapping (torch.Tensor [2, n_neighbors]): - A tensor containing the indices of the neighbor list for the given - positions array. `mapping[0]` corresponds to the central atom indices, - and `mapping[1]` corresponds to the neighbor atom indices. + Pairs of atom indices; ``mapping[0]`` are central atoms, + ``mapping[1]`` are neighbors. system_mapping (torch.Tensor [n_neighbors]): - A tensor mapping the neighbor atoms to their respective structures. + System assignment for each pair. shifts_idx (torch.Tensor [n_neighbors, 3]): - A tensor containing the cell shift indices used to reconstruct the - neighbor atom positions. - - Example: - >>> # Create a batched system with 2 structures - >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]]) - >>> cell = torch.eye(3).repeat(2, 1) * 10.0 # Two cells - >>> pbc = torch.tensor([[True, True, True], [True, True, True]]) - >>> cutoff = torch.tensor(2.0) - >>> # First 2 atoms in system 0, last in system 1 - >>> system_idx = torch.tensor([0, 0, 1]) - >>> mapping, sys_map, shifts = torch_nl_n2( - ... positions, cell, pbc, cutoff, system_idx - ... ) + Cell shift indices valid for the **original** input positions. References: - https://github.com/felixmusil/torch_nl - - https://github.com/venkatkapil24/batch_nl: inspired the use of `pad_sequence` - to vectorize a previous implementation that used a loop to iterate over systems - inside the `build_naive_neighborhood` function. + - https://github.com/venkatkapil24/batch_nl """ n_systems = system_idx.max().item() + 1 cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) + wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts( + positions, cell, system_idx, pbc + ) n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( - positions, cell, pbc, cutoff.item(), n_atoms, self_interaction + wrapped, cell, pbc, cutoff.item(), n_atoms, self_interaction ) mapping, mapping_system, shifts_idx = strict_nl( - cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx + cutoff.item(), wrapped, cell, mapping, system_mapping, shifts_idx ) + shifts_idx = shifts_idx + wrap_shifts[mapping[0]] - wrap_shifts[mapping[1]] return mapping, mapping_system, shifts_idx @@ -200,15 +189,16 @@ def torch_nl_linked_cell( """Compute the neighbor list for a set of atomic structures using the linked cell algorithm before applying a strict `cutoff`. - The atomic positions `pos` should be wrapped inside their respective unit cells. + Positions are wrapped into the primary cell internally for the search; the + returned ``shifts_idx`` are corrected so they remain valid for the **original** + (unwrapped) input positions. The input tensor is never modified. This is the recommended default for batched neighbor list calculations as it provides good performance for systems of various sizes using the linked cell algorithm which has O(N) complexity. Args: - positions (torch.Tensor [n_atom, 3]): A tensor containing the positions - of atoms wrapped inside their respective unit cells. + positions (torch.Tensor [n_atom, 3]): Cartesian positions (may be unwrapped). cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors. pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. @@ -222,41 +212,29 @@ def torch_nl_linked_cell( Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - A tuple containing: - - mapping (torch.Tensor [2, n_neighbors]): - A tensor containing the indices of the neighbor list for the given - positions array. `mapping[0]` corresponds to the central atom - indices, and `mapping[1]` corresponds to the neighbor atom indices. - - system_mapping (torch.Tensor [n_neighbors]): - A tensor mapping the neighbor atoms to their respective structures. - - shifts_idx (torch.Tensor [n_neighbors, 3]): - A tensor containing the cell shift indices used to reconstruct the - neighbor atom positions. - - Example: - >>> # Create a batched system with 2 structures - >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]]) - >>> cell = torch.eye(3).repeat(2, 1) * 10.0 # Two cells - >>> pbc = torch.tensor([[True, True, True], [True, True, True]]) - >>> cutoff = torch.tensor(2.0) - >>> # First 2 atoms in system 0, last in system 1 - >>> system_idx = torch.tensor([0, 0, 1]) - >>> mapping, sys_map, shifts = torch_nl_linked_cell( - ... positions, cell, pbc, cutoff, system_idx - ... ) + mapping (torch.Tensor [2, n_neighbors]): + Pairs of atom indices; ``mapping[0]`` are central atoms, + ``mapping[1]`` are neighbors. + system_mapping (torch.Tensor [n_neighbors]): + System assignment for each pair. + shifts_idx (torch.Tensor [n_neighbors, 3]): + Cell shift indices valid for the **original** input positions. References: - https://github.com/felixmusil/torch_nl """ n_systems = system_idx.max().item() + 1 cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) + wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts( + positions, cell, system_idx, pbc + ) n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( - positions, cell, pbc, cutoff.item(), n_atoms, self_interaction + wrapped, cell, pbc, cutoff.item(), n_atoms, self_interaction ) - mapping, mapping_system, shifts_idx = strict_nl( - cutoff.item(), positions, cell, mapping, system_mapping, shifts_idx + cutoff.item(), wrapped, cell, mapping, system_mapping, shifts_idx ) + shifts_idx = shifts_idx + wrap_shifts[mapping[0]] - wrap_shifts[mapping[1]] return mapping, mapping_system, shifts_idx diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index e7b0c27b3..33aa81555 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -12,16 +12,14 @@ try: from vesin import NeighborList as VesinNeighborList except ImportError: - VesinNeighborList = None - -# Try to import torch version (may not exist in all vesin versions) + VesinNeighborList = None # ty:ignore[invalid-assignment] try: from vesin.torch import NeighborList as VesinNeighborListTorch except ImportError: - # vesin.torch may not exist - use regular NeighborList for torch compatibility - VesinNeighborListTorch = VesinNeighborList + VesinNeighborListTorch = None # ty:ignore[invalid-assignment] VESIN_AVAILABLE = VesinNeighborList is not None +VESIN_TORCH_AVAILABLE = VesinNeighborListTorch is not None if VESIN_AVAILABLE: @@ -74,7 +72,10 @@ def vesin_nl_ts( from torch_sim.neighbors import _normalize_inputs if VesinNeighborListTorch is None: - raise RuntimeError("vesin package is not installed") + raise RuntimeError( + "vesin.torch is not available. " + "Install it with: [uv] pip install vesin[torch]" + ) device = positions.device dtype = positions.dtype n_systems = int(system_idx.max().item()) + 1 @@ -199,7 +200,9 @@ def vesin_nl( from torch_sim.neighbors import _normalize_inputs if VesinNeighborList is None: - raise RuntimeError("vesin package is not installed") + raise RuntimeError( + "vesin is not installed. Install it with: [uv] pip install vesin" + ) device = positions.device dtype = positions.dtype n_systems = int(system_idx.max().item()) + 1 diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 27b2926fb..bdc41f1e6 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -139,43 +139,75 @@ def pbc_wrap_batched( """ if isinstance(pbc, bool): pbc = torch.tensor([pbc, pbc, pbc], dtype=torch.bool, device=positions.device) - - # Validate inputs if not torch.is_floating_point(positions) or not torch.is_floating_point(cell): raise TypeError("Positions and lattice vectors must be floating point tensors.") - if positions.shape[-1] != cell.shape[-1]: raise ValueError("Position dimensionality must match lattice vectors.") - - # Get unique system indices and counts uniq_systems = torch.unique(system_idx) n_systems = len(uniq_systems) - if n_systems != cell.shape[0]: raise ValueError( f"Number of unique systems ({n_systems}) doesn't " f"match number of cells ({cell.shape[0]})" ) + pbc_batched = pbc.unsqueeze(0).expand(n_systems, -1) + wrapped, _ = pbc_wrap_batched_and_get_lattice_shifts( + positions, cell.mT, system_idx, pbc_batched + ) + return wrapped + + +def pbc_wrap_batched_and_get_lattice_shifts( + positions: torch.Tensor, + cell: torch.Tensor, + system_idx: torch.Tensor, + pbc: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Wrap Cartesian positions into the primary cell and return the applied shifts. + + ``cell`` rows are lattice vectors (row-vector convention matching + ``compute_cell_shifts`` and the batched neighbor-list APIs). Fractional coordinates + use ``cell_col = cell[s].T`` so ``r = f @ cell[s]`` with ``f`` in ``[0, 1)`` on + periodic axes. Atoms in non-periodic systems or systems with singular cells are + left unchanged. + + Returns ``(wrapped_positions, lattice_shifts)`` where ``lattice_shifts[i]`` is the + integer vector ``floor(frac_i)`` on periodic axes (zero elsewhere). The neighbor-list + code uses these to correct ``shifts_idx`` so they remain valid for the original + (unwrapped) input positions. + """ + cell_T = cell.transpose(1, 2) + dets = torch.linalg.det(cell_T) + invertible = torch.isfinite(dets) & (dets.abs() > 1e-12) + active = pbc.any(dim=1) & invertible + + if not active.any(): + return positions.clone(), torch.zeros_like(positions, dtype=cell.dtype) - # Efficient approach without explicit loops - # Get the cell for each atom based on its system index - B = torch.linalg.inv(cell) # Shape: (n_systems, 3, 3) + # Get the inverse cell for each atom based on its system index + B = torch.zeros_like(cell) # Shape: (n_systems, 3, 3) + B[active] = torch.linalg.inv(cell_T[active]) B_per_atom = B[system_idx] # Shape: (n_atoms, 3, 3) # Transform to fractional coordinates: f = B·r # For each atom, multiply its position by its system's inverse cell matrix - frac_coords = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2) + frac = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2) - # Wrap to reference cell [0,1) using modulo - wrapped_frac = frac_coords.clone() - wrapped_frac[:, pbc] = frac_coords[:, pbc] % 1.0 + pbc_per_atom = pbc[system_idx] + active_per_atom = active[system_idx].unsqueeze(1) + pbc_mask = pbc_per_atom & active_per_atom - # Transform back to real space: r = A·f - # Get the cell for each atom based on its system index - cell_per_atom = cell[system_idx] # Shape: (n_atoms, 3, 3) + # Wrap to reference cell [0,1) using floor + int_shifts = torch.where(pbc_mask, torch.floor(frac), torch.zeros_like(frac)) + wrapped_frac = frac - int_shifts + # Transform back to real space: r = A·f # For each atom, multiply its wrapped fractional coords by its system's cell matrix - return torch.bmm(cell_per_atom, wrapped_frac.unsqueeze(2)).squeeze(2) + cell_per_atom = cell_T[system_idx] # Shape: (n_atoms, 3, 3) + wrapped_pos = torch.bmm(cell_per_atom, wrapped_frac.unsqueeze(2)).squeeze(2) + out = torch.where(active_per_atom, wrapped_pos, positions) + shifts = torch.where(active_per_atom, int_shifts, torch.zeros_like(int_shifts)) + return out, shifts def minimum_image_displacement(