Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,65 +261,77 @@ def test_state_round_trip(
assert torch.allclose(sim_state.masses, round_trip_state.masses)


def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
def test_state_to_atoms_importerror(
monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState
) -> None:
monkeypatch.setitem(sys.modules, "ase", None)
monkeypatch.setitem(sys.modules, "ase.data", None)

with pytest.raises(
ImportError, match="ASE is required for state_to_atoms conversion"
):
ts.io.state_to_atoms(None)
ts.io.state_to_atoms(si_sim_state)


def test_state_to_phonopy_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
def test_state_to_phonopy_importerror(
monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState
) -> None:
monkeypatch.setitem(sys.modules, "phonopy", None)
monkeypatch.setitem(sys.modules, "phonopy.structure", None)
monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None)

with pytest.raises(
ImportError, match="Phonopy is required for state_to_phonopy conversion"
):
ts.io.state_to_phonopy(None)
ts.io.state_to_phonopy(si_sim_state)


def test_state_to_structures_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
def test_state_to_structures_importerror(
monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState
) -> None:
monkeypatch.setitem(sys.modules, "pymatgen", None)
monkeypatch.setitem(sys.modules, "pymatgen.core", None)
monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None)

with pytest.raises(
ImportError, match="Pymatgen is required for state_to_structures conversion"
):
ts.io.state_to_structures(None)
ts.io.state_to_structures(si_sim_state)


def test_atoms_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
def test_atoms_to_state_importerror(
monkeypatch: pytest.MonkeyPatch, si_atoms: Atoms
) -> None:
monkeypatch.setitem(sys.modules, "ase", None)
monkeypatch.setitem(sys.modules, "ase.data", None)

with pytest.raises(
ImportError, match="ASE is required for atoms_to_state conversion"
):
ts.io.atoms_to_state(None, None, None)
ts.io.atoms_to_state(si_atoms, torch.device("cpu"), torch.float64)


def test_phonopy_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
def test_phonopy_to_state_importerror(
monkeypatch: pytest.MonkeyPatch, si_phonopy_atoms: PhonopyAtoms
) -> None:
monkeypatch.setitem(sys.modules, "phonopy", None)
monkeypatch.setitem(sys.modules, "phonopy.structure", None)
monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None)

with pytest.raises(
ImportError, match="Phonopy is required for phonopy_to_state conversion"
):
ts.io.phonopy_to_state(None, None, None)
ts.io.phonopy_to_state(si_phonopy_atoms, torch.device("cpu"), torch.float64)


def test_structures_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
def test_structures_to_state_importerror(
monkeypatch: pytest.MonkeyPatch, si_structure: Structure
) -> None:
monkeypatch.setitem(sys.modules, "pymatgen", None)
monkeypatch.setitem(sys.modules, "pymatgen.core", None)
monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None)

with pytest.raises(
ImportError, match="Pymatgen is required for structures_to_state conversion"
):
ts.io.structures_to_state(None, None, None)
ts.io.structures_to_state(si_structure, torch.device("cpu"), torch.float64)
61 changes: 36 additions & 25 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from collections.abc import Callable

import numpy as np
import psutil
Expand Down Expand Up @@ -104,7 +105,7 @@ def ase_to_torch_batch(


@pytest.fixture
def periodic_atoms_set():
def periodic_atoms_set() -> list[Atoms]:
return [
bulk("Si", "diamond", a=6, cubic=True),
bulk("Si", "diamond", a=6),
Expand All @@ -123,19 +124,21 @@ def periodic_atoms_set():


@pytest.fixture
def molecule_atoms_set() -> list:
def molecule_atoms_set() -> list[Atoms]:
return [
*map(molecule, ("CH3CH2NH2", "H2O", "methylenecyclopropane", "OCHCHO", "C3H9C")),
]


@pytest.mark.parametrize("cutoff", [1, 3, 5, 7])
@pytest.mark.parametrize("use_jit", [True, False])
@pytest.mark.parametrize("atoms_list", ["periodic_atoms_set", "molecule_atoms_set"])
@pytest.mark.parametrize(
"atoms_list_fixture", ["periodic_atoms_set", "molecule_atoms_set"]
)
def test_primitive_neighbor_list(
*,
cutoff: float,
atoms_list: str,
atoms_list_fixture: str,
device: torch.device,
dtype: torch.dtype,
use_jit: bool,
Expand All @@ -150,7 +153,7 @@ def test_primitive_neighbor_list(
dtype: Torch dtype to use
use_jit: Whether to use the jitted version or disable JIT
"""
atoms_list = request.getfixturevalue(atoms_list)
atoms_list = request.getfixturevalue(atoms_list_fixture)

# Create a non-jitted version of the function if requested
if use_jit:
Expand Down Expand Up @@ -209,7 +212,7 @@ def test_primitive_neighbor_list(
dds_prim = transforms.compute_distances_with_cell_shifts(
pos, mapping, cell_shifts_prim
)
dds_prim = np.sort(dds_prim.numpy())
dds_prim_sorted = np.sort(dds_prim.numpy())

# Get the neighbor list from ase
idx_i_ref, idx_j_ref, shifts_ref, dist_ref = neighbor_list(
Expand All @@ -235,37 +238,39 @@ def test_primitive_neighbor_list(
)

# Sort the distances
dds_ref = np.sort(dds_ref.numpy())
dist_ref = np.sort(dist_ref)
dds_ref_sorted = np.sort(dds_ref.numpy())
dist_ref_sorted = np.sort(dist_ref)

# Check that the distances are the same with ase and torchsim logic
np.testing.assert_allclose(dds_ref, dist_ref)
np.testing.assert_allclose(dds_ref_sorted, dist_ref_sorted)

# Check that the primitive_neighbor_list distances match ASE's
np.testing.assert_allclose(
dds_prim, dist_ref, err_msg=f"Failed with use_jit={use_jit}"
dds_prim_sorted, dist_ref_sorted, err_msg=f"Failed with use_jit={use_jit}"
)


@pytest.mark.parametrize("cutoff", [1, 3, 5, 7])
@pytest.mark.parametrize("atoms_list", ["periodic_atoms_set", "molecule_atoms_set"])
@pytest.mark.parametrize(
"atoms_list_fixture", ["periodic_atoms_set", "molecule_atoms_set"]
)
@pytest.mark.parametrize(
"nl_implementation",
[neighbors.standard_nl, neighbors.vesin_nl, neighbors.vesin_nl_ts],
)
def test_neighbor_list_implementations(
*,
cutoff: float,
atoms_list: str,
nl_implementation: callable,
atoms_list_fixture: str,
nl_implementation: Callable,
device: torch.device,
dtype: torch.dtype,
request: pytest.FixtureRequest,
) -> None:
"""Check that different neighbor list implementations give the same results as ASE
by comparing the resulting sorted list of distances between neighbors.
"""
atoms_list = request.getfixturevalue(atoms_list)
atoms_list = request.getfixturevalue(atoms_list_fixture)

for atoms in atoms_list:
# Convert to torch tensors
Expand All @@ -284,7 +289,7 @@ def test_neighbor_list_implementations(
# Calculate distances with cell shifts
cell_shifts = torch.mm(shifts, row_vector_cell)
dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
dds = np.sort(dds.numpy())
dds_sorted = np.sort(dds.numpy())

# Get the reference neighbor list from ASE
idx_i, idx_j, shifts_ref, dist = neighbor_list(
Expand All @@ -306,13 +311,13 @@ def test_neighbor_list_implementations(
dds_ref = transforms.compute_distances_with_cell_shifts(
pos, mapping_ref, cell_shifts_ref
)
dds_ref = np.sort(dds_ref.numpy())
dist_ref = np.sort(dist)
dds_ref_sorted = np.sort(dds_ref.numpy())
dist_ref_sorted = np.sort(dist)

# Verify results
np.testing.assert_allclose(dds_ref, dist_ref)
np.testing.assert_allclose(dds, dds_ref)
np.testing.assert_allclose(dds, dist_ref)
np.testing.assert_allclose(dds_ref_sorted, dist_ref_sorted)
np.testing.assert_allclose(dds_sorted, dds_ref_sorted)
np.testing.assert_allclose(dds_sorted, dist_ref_sorted)


@pytest.mark.parametrize("cutoff", [1, 3, 5, 7])
Expand All @@ -325,7 +330,7 @@ def test_torch_nl_implementations(
*,
cutoff: float,
self_interaction: bool,
nl_implementation: callable,
nl_implementation: Callable,
device: torch.device,
dtype: torch.dtype,
molecule_atoms_set: list[Atoms],
Expand All @@ -351,7 +356,7 @@ def test_torch_nl_implementations(
row_vector_cell, shifts_idx, mapping_system
)
dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
dds = np.sort(dds.numpy())
dds_sorted = np.sort(dds.numpy())

# Get reference results from ASE
dd_ref = []
Expand All @@ -364,10 +369,10 @@ def test_torch_nl_implementations(
max_nbins=1e6,
)
dd_ref.extend(dist)
dd_ref = np.sort(dd_ref)
dd_ref_sorted = np.sort(dd_ref)

# Verify results
np.testing.assert_allclose(dd_ref, dds)
np.testing.assert_allclose(dd_ref_sorted, dds_sorted)


def test_primitive_neighbor_list_edge_cases(
Expand Down Expand Up @@ -560,7 +565,13 @@ def test_neighbor_lists_time_and_memory(
# Fix pbc tensor shape
pbc = torch.tensor([[True, True, True]], device=device)
mapping, mapping_system, shifts_idx = nl_fn(
cutoff, pos, cell, pbc, system_idx, self_interaction=False
cutoff=cutoff,
positions=pos,
cell=cell,
# TODO: standardize all pbc so we either use tensors/booleans/tuples.
pbc=pbc, # type: ignore[arg-type]
system_idx=system_idx,
self_interaction=False, # type: ignore[call-arg, misc]
)
else:
mapping, shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff)
Expand Down
20 changes: 11 additions & 9 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def test_pbc_wrap_batched_preserves_relative_positions(
system_idx_mask = state.system_idx == b

# Calculate pairwise distances before wrapping
atoms_in_batch = torch.sum(system_idx_mask).item()
atoms_in_batch = int(torch.sum(system_idx_mask).item())
for n_atoms in range(atoms_in_batch - 1):
for j in range(n_atoms + 1, atoms_in_batch):
# Get the indices of atoms i and j in this batch
Expand Down Expand Up @@ -698,7 +698,7 @@ def constant_fn(dr: torch.Tensor) -> torch.Tensor:
return torch.ones_like(dr)

cutoff_fn = tst.multiplicative_isotropic_cutoff(
constant_fn, r_onset=1.0, r_cutoff=2.0
constant_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
)

# Test points in different regions
Expand All @@ -716,8 +716,8 @@ def test_multiplicative_isotropic_cutoff_continuity() -> None:
def linear_fn(dr: torch.Tensor) -> torch.Tensor:
return dr

r_onset = 1.0
r_cutoff = 2.0
r_onset = torch.tensor(1.0)
r_cutoff = torch.tensor(2.0)
cutoff_fn = tst.multiplicative_isotropic_cutoff(linear_fn, r_onset, r_cutoff)

# Test near onset
Expand All @@ -741,8 +741,8 @@ def test_multiplicative_isotropic_cutoff_derivative_continuity() -> None:
def quadratic_fn(dr: torch.Tensor) -> torch.Tensor:
return dr**2

r_onset = 1.0
r_cutoff = 2.0
r_onset = torch.tensor(1.0)
r_cutoff = torch.tensor(2.0)
cutoff_fn = tst.multiplicative_isotropic_cutoff(quadratic_fn, r_onset, r_cutoff)

# Test derivative near onset and cutoff using finite differences
Expand All @@ -764,7 +764,7 @@ def parameterized_fn(dr: torch.Tensor, scale: float) -> torch.Tensor:
return scale * dr

cutoff_fn = tst.multiplicative_isotropic_cutoff(
parameterized_fn, r_onset=1.0, r_cutoff=2.0
parameterized_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
)

dr = torch.tensor([0.5, 1.5, 2.5])
Expand All @@ -782,7 +782,7 @@ def constant_fn(dr: torch.Tensor) -> torch.Tensor:
return torch.ones_like(dr)

cutoff_fn = tst.multiplicative_isotropic_cutoff(
constant_fn, r_onset=1.0, r_cutoff=2.0
constant_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
)

# Test with 2D input
Expand All @@ -800,7 +800,9 @@ def test_multiplicative_isotropic_cutoff_gradient() -> None:
def linear_fn(dr: torch.Tensor) -> torch.Tensor:
return dr

cutoff_fn = tst.multiplicative_isotropic_cutoff(linear_fn, r_onset=1.0, r_cutoff=2.0)
cutoff_fn = tst.multiplicative_isotropic_cutoff(
linear_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
)

dr = torch.tensor([1.5], requires_grad=True)
result = cutoff_fn(dr)
Expand Down
Loading