diff --git a/tests/test_io.py b/tests/test_io.py
index c5043763a..e90f6ab59 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -261,17 +261,21 @@ def test_state_round_trip(
     assert torch.allclose(sim_state.masses, round_trip_state.masses)
 
 
-def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_state_to_atoms_importerror(
+    monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState
+) -> None:
     monkeypatch.setitem(sys.modules, "ase", None)
     monkeypatch.setitem(sys.modules, "ase.data", None)
 
     with pytest.raises(
         ImportError, match="ASE is required for state_to_atoms conversion"
     ):
-        ts.io.state_to_atoms(None)
+        ts.io.state_to_atoms(si_sim_state)
 
 
-def test_state_to_phonopy_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_state_to_phonopy_importerror(
+    monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState
+) -> None:
     monkeypatch.setitem(sys.modules, "phonopy", None)
     monkeypatch.setitem(sys.modules, "phonopy.structure", None)
     monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None)
@@ -279,10 +283,12 @@ def test_state_to_phonopy_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
     with pytest.raises(
         ImportError, match="Phonopy is required for state_to_phonopy conversion"
     ):
-        ts.io.state_to_phonopy(None)
+        ts.io.state_to_phonopy(si_sim_state)
 
 
-def test_state_to_structures_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_state_to_structures_importerror(
+    monkeypatch: pytest.MonkeyPatch, si_sim_state: ts.SimState
+) -> None:
     monkeypatch.setitem(sys.modules, "pymatgen", None)
     monkeypatch.setitem(sys.modules, "pymatgen.core", None)
     monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None)
@@ -290,20 +296,24 @@ def test_state_to_structures_importerror(monkeypatch: pytest.MonkeyPatch) -> Non
     with pytest.raises(
         ImportError, match="Pymatgen is required for state_to_structures conversion"
     ):
-        ts.io.state_to_structures(None)
+        ts.io.state_to_structures(si_sim_state)
 
 
-def test_atoms_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_atoms_to_state_importerror(
+    monkeypatch: pytest.MonkeyPatch, si_atoms: Atoms
+) -> None:
     monkeypatch.setitem(sys.modules, "ase", None)
     monkeypatch.setitem(sys.modules, "ase.data", None)
 
     with pytest.raises(
         ImportError, match="ASE is required for atoms_to_state conversion"
     ):
-        ts.io.atoms_to_state(None, None, None)
+        ts.io.atoms_to_state(si_atoms, torch.device("cpu"), torch.float64)
 
 
-def test_phonopy_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_phonopy_to_state_importerror(
+    monkeypatch: pytest.MonkeyPatch, si_phonopy_atoms: PhonopyAtoms
+) -> None:
     monkeypatch.setitem(sys.modules, "phonopy", None)
     monkeypatch.setitem(sys.modules, "phonopy.structure", None)
     monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None)
@@ -311,10 +321,12 @@ def test_phonopy_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
     with pytest.raises(
         ImportError, match="Phonopy is required for phonopy_to_state conversion"
     ):
-        ts.io.phonopy_to_state(None, None, None)
+        ts.io.phonopy_to_state(si_phonopy_atoms, torch.device("cpu"), torch.float64)
 
 
-def test_structures_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
+def test_structures_to_state_importerror(
+    monkeypatch: pytest.MonkeyPatch, si_structure: Structure
+) -> None:
     monkeypatch.setitem(sys.modules, "pymatgen", None)
     monkeypatch.setitem(sys.modules, "pymatgen.core", None)
     monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None)
@@ -322,4 +334,4 @@ def test_structures_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> Non
     with pytest.raises(
         ImportError, match="Pymatgen is required for structures_to_state conversion"
     ):
-        ts.io.structures_to_state(None, None, None)
+        ts.io.structures_to_state(si_structure, torch.device("cpu"), torch.float64)
diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py
index ac948acbb..dbf625101 100644
--- a/tests/test_neighbors.py
+++ b/tests/test_neighbors.py
@@ -1,4 +1,5 @@
 import time
+from collections.abc import Callable
 
 import numpy as np
 import psutil
@@ -104,7 +105,7 @@ def ase_to_torch_batch(
 
 
 @pytest.fixture
-def periodic_atoms_set():
+def periodic_atoms_set() -> list[Atoms]:
     return [
         bulk("Si", "diamond", a=6, cubic=True),
         bulk("Si", "diamond", a=6),
@@ -123,7 +124,7 @@
 
 
 @pytest.fixture
-def molecule_atoms_set() -> list:
+def molecule_atoms_set() -> list[Atoms]:
     return [
         *map(molecule, ("CH3CH2NH2", "H2O", "methylenecyclopropane", "OCHCHO", "C3H9C")),
     ]
@@ -131,11 +132,13 @@ def molecule_atoms_set() -> list:
 
 @pytest.mark.parametrize("cutoff", [1, 3, 5, 7])
 @pytest.mark.parametrize("use_jit", [True, False])
-@pytest.mark.parametrize("atoms_list", ["periodic_atoms_set", "molecule_atoms_set"])
+@pytest.mark.parametrize(
+    "atoms_list_fixture", ["periodic_atoms_set", "molecule_atoms_set"]
+)
 def test_primitive_neighbor_list(
     *,
     cutoff: float,
-    atoms_list: str,
+    atoms_list_fixture: str,
     device: torch.device,
     dtype: torch.dtype,
     use_jit: bool,
@@ -150,7 +153,7 @@ def test_primitive_neighbor_list(
         dtype: Torch dtype to use
        use_jit: Whether to use the jitted version or disable JIT
     """
-    atoms_list = request.getfixturevalue(atoms_list)
+    atoms_list = request.getfixturevalue(atoms_list_fixture)
 
     # Create a non-jitted version of the function if requested
     if use_jit:
@@ -209,7 +212,7 @@ def test_primitive_neighbor_list(
         dds_prim = transforms.compute_distances_with_cell_shifts(
             pos, mapping, cell_shifts_prim
         )
-        dds_prim = np.sort(dds_prim.numpy())
+        dds_prim_sorted = np.sort(dds_prim.numpy())
 
         # Get the neighbor list from ase
         idx_i_ref, idx_j_ref, shifts_ref, dist_ref = neighbor_list(
@@ -235,20 +238,22 @@ def test_primitive_neighbor_list(
         )
         # Sort the distances
-        dds_ref = np.sort(dds_ref.numpy())
-        dist_ref = np.sort(dist_ref)
+        dds_ref_sorted = np.sort(dds_ref.numpy())
+        dist_ref_sorted = np.sort(dist_ref)
 
         # Check that the distances are the same with ase and torchsim logic
-        np.testing.assert_allclose(dds_ref, dist_ref)
+        np.testing.assert_allclose(dds_ref_sorted, dist_ref_sorted)
 
         # Check that the primitive_neighbor_list distances match ASE's
         np.testing.assert_allclose(
-            dds_prim, dist_ref, err_msg=f"Failed with use_jit={use_jit}"
+            dds_prim_sorted, dist_ref_sorted, err_msg=f"Failed with use_jit={use_jit}"
         )
 
 
 @pytest.mark.parametrize("cutoff", [1, 3, 5, 7])
-@pytest.mark.parametrize("atoms_list", ["periodic_atoms_set", "molecule_atoms_set"])
+@pytest.mark.parametrize(
+    "atoms_list_fixture", ["periodic_atoms_set", "molecule_atoms_set"]
+)
 @pytest.mark.parametrize(
     "nl_implementation",
     [neighbors.standard_nl, neighbors.vesin_nl, neighbors.vesin_nl_ts],
 )
@@ -256,8 +261,8 @@ def test_neighbor_list_implementations(
     *,
     cutoff: float,
-    atoms_list: str,
-    nl_implementation: callable,
+    atoms_list_fixture: str,
+    nl_implementation: Callable,
     device: torch.device,
     dtype: torch.dtype,
     request: pytest.FixtureRequest,
 ) -> None:
@@ -265,7 +270,7 @@ def test_neighbor_list_implementations(
     """Check that different neighbor list implementations give the same
     results as ASE by comparing the resulting sorted list of distances between neighbors.
     """
-    atoms_list = request.getfixturevalue(atoms_list)
+    atoms_list = request.getfixturevalue(atoms_list_fixture)
 
     for atoms in atoms_list:
         # Convert to torch tensors
@@ -284,7 +289,7 @@ def test_neighbor_list_implementations(
         # Calculate distances with cell shifts
         cell_shifts = torch.mm(shifts, row_vector_cell)
         dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
-        dds = np.sort(dds.numpy())
+        dds_sorted = np.sort(dds.numpy())
 
         # Get the reference neighbor list from ASE
         idx_i, idx_j, shifts_ref, dist = neighbor_list(
@@ -306,13 +311,13 @@ def test_neighbor_list_implementations(
         dds_ref = transforms.compute_distances_with_cell_shifts(
             pos, mapping_ref, cell_shifts_ref
         )
-        dds_ref = np.sort(dds_ref.numpy())
-        dist_ref = np.sort(dist)
+        dds_ref_sorted = np.sort(dds_ref.numpy())
+        dist_ref_sorted = np.sort(dist)
 
         # Verify results
-        np.testing.assert_allclose(dds_ref, dist_ref)
-        np.testing.assert_allclose(dds, dds_ref)
-        np.testing.assert_allclose(dds, dist_ref)
+        np.testing.assert_allclose(dds_ref_sorted, dist_ref_sorted)
+        np.testing.assert_allclose(dds_sorted, dds_ref_sorted)
+        np.testing.assert_allclose(dds_sorted, dist_ref_sorted)
 
 
 @pytest.mark.parametrize("cutoff", [1, 3, 5, 7])
@@ -325,7 +330,7 @@ def test_torch_nl_implementations(
     *,
     cutoff: float,
     self_interaction: bool,
-    nl_implementation: callable,
+    nl_implementation: Callable,
     device: torch.device,
     dtype: torch.dtype,
     molecule_atoms_set: list[Atoms],
@@ -351,7 +356,7 @@ def test_torch_nl_implementations(
         row_vector_cell, shifts_idx, mapping_system
     )
     dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
-    dds = np.sort(dds.numpy())
+    dds_sorted = np.sort(dds.numpy())
 
     # Get reference results from ASE
     dd_ref = []
@@ -364,10 +369,10 @@ def test_torch_nl_implementations(
             max_nbins=1e6,
         )
         dd_ref.extend(dist)
-    dd_ref = np.sort(dd_ref)
+    dd_ref_sorted = np.sort(dd_ref)
 
     # Verify results
-    np.testing.assert_allclose(dd_ref, dds)
+    np.testing.assert_allclose(dd_ref_sorted, dds_sorted)
 
 
 def test_primitive_neighbor_list_edge_cases(
@@ -560,7 +565,13 @@ def test_neighbor_lists_time_and_memory(
             # Fix pbc tensor shape
             pbc = torch.tensor([[True, True, True]], device=device)
             mapping, mapping_system, shifts_idx = nl_fn(
-                cutoff, pos, cell, pbc, system_idx, self_interaction=False
+                cutoff=cutoff,
+                positions=pos,
+                cell=cell,
+                # TODO: standardize all pbc so we either use tensors/booleans/tuples.
+                pbc=pbc,  # type: ignore[arg-type]
+                system_idx=system_idx,
+                self_interaction=False,  # type: ignore[call-arg, misc]
             )
         else:
             mapping, shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff)
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index ca965c695..c9317cdf9 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -456,7 +456,7 @@ def test_pbc_wrap_batched_preserves_relative_positions(
         system_idx_mask = state.system_idx == b
 
         # Calculate pairwise distances before wrapping
-        atoms_in_batch = torch.sum(system_idx_mask).item()
+        atoms_in_batch = int(torch.sum(system_idx_mask).item())
         for n_atoms in range(atoms_in_batch - 1):
             for j in range(n_atoms + 1, atoms_in_batch):
                 # Get the indices of atoms i and j in this batch
@@ -698,7 +698,7 @@ def constant_fn(dr: torch.Tensor) -> torch.Tensor:
         return torch.ones_like(dr)
 
     cutoff_fn = tst.multiplicative_isotropic_cutoff(
-        constant_fn, r_onset=1.0, r_cutoff=2.0
+        constant_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
     )
 
     # Test points in different regions
@@ -716,8 +716,8 @@ def test_multiplicative_isotropic_cutoff_continuity() -> None:
     def linear_fn(dr: torch.Tensor) -> torch.Tensor:
         return dr
 
-    r_onset = 1.0
-    r_cutoff = 2.0
+    r_onset = torch.tensor(1.0)
+    r_cutoff = torch.tensor(2.0)
     cutoff_fn = tst.multiplicative_isotropic_cutoff(linear_fn, r_onset, r_cutoff)
 
     # Test near onset
@@ -741,8 +741,8 @@ def test_multiplicative_isotropic_cutoff_derivative_continuity() -> None:
     def quadratic_fn(dr: torch.Tensor) -> torch.Tensor:
         return dr**2
 
-    r_onset = 1.0
-    r_cutoff = 2.0
+    r_onset = torch.tensor(1.0)
+    r_cutoff = torch.tensor(2.0)
     cutoff_fn = tst.multiplicative_isotropic_cutoff(quadratic_fn, r_onset, r_cutoff)
 
     # Test derivative near onset and cutoff using finite differences
@@ -764,7 +764,7 @@ def parameterized_fn(dr: torch.Tensor, scale: float) -> torch.Tensor:
         return scale * dr
 
     cutoff_fn = tst.multiplicative_isotropic_cutoff(
-        parameterized_fn, r_onset=1.0, r_cutoff=2.0
+        parameterized_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
     )
 
     dr = torch.tensor([0.5, 1.5, 2.5])
@@ -782,7 +782,7 @@ def constant_fn(dr: torch.Tensor) -> torch.Tensor:
         return torch.ones_like(dr)
 
     cutoff_fn = tst.multiplicative_isotropic_cutoff(
-        constant_fn, r_onset=1.0, r_cutoff=2.0
+        constant_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
     )
 
     # Test with 2D input
@@ -800,7 +800,9 @@ def test_multiplicative_isotropic_cutoff_gradient() -> None:
     def linear_fn(dr: torch.Tensor) -> torch.Tensor:
         return dr
 
-    cutoff_fn = tst.multiplicative_isotropic_cutoff(linear_fn, r_onset=1.0, r_cutoff=2.0)
+    cutoff_fn = tst.multiplicative_isotropic_cutoff(
+        linear_fn, r_onset=torch.tensor(1.0), r_cutoff=torch.tensor(2.0)
+    )
 
     dr = torch.tensor([1.5], requires_grad=True)
     result = cutoff_fn(dr)
diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py
index 40091eee8..05cba999e 100644
--- a/torch_sim/neighbors.py
+++ b/torch_sim/neighbors.py
@@ -174,7 +174,7 @@ def primitive_neighbor_list(  # noqa: C901, PLR0915
                 bin_index_ic[:, c], n_bins_c[c]
             )
         else:
-            bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, n_bins_c[c] - 1)
+            bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, n_bins_c[c] - 1)  # type: ignore[call-overload]
 
     # Convert Cartesian bin index to unique scalar bin index.
     bin_index_i = bin_index_ic[:, 0] + n_bins_c[0] * (
@@ -193,8 +193,8 @@ def primitive_neighbor_list(  # noqa: C901, PLR0915
     # homogeneous, i.e. has the same size *max_n_atoms_per_bin* for all bins.
     # The list is padded with -1 values.
     atoms_in_bin_ba = -torch.ones(
-        n_bins, max_n_atoms_per_bin.item(), dtype=torch.long, device=device
-    )
+        n_bins, max_n_atoms_per_bin, dtype=torch.long, device=device
+    )  # type: ignore[call-overload]
     for bin_cnt in range(int(max_n_atoms_per_bin.item())):
         # Create a mask array that identifies the first atom of each bin.
         mask = torch.cat(
@@ -227,8 +227,8 @@ def primitive_neighbor_list(  # noqa: C901, PLR0915
     #     (max_n_atoms_per_bin, max_n_atoms_per_bin), dtype=int
     # ).reshape(2, -1)
     atom_pairs_pn = torch.cartesian_prod(
-        torch.arange(max_n_atoms_per_bin, device=device),
-        torch.arange(max_n_atoms_per_bin, device=device),
+        torch.arange(max_n_atoms_per_bin, device=device),  # type: ignore[call-overload]
+        torch.arange(max_n_atoms_per_bin, device=device),  # type: ignore[call-overload]
     )
     atom_pairs_pn = atom_pairs_pn.T.reshape(2, -1)
@@ -244,9 +244,9 @@ def primitive_neighbor_list(  # noqa: C901, PLR0915
     # that each bin contains exactly max_n_atoms_per_bin atoms. We then throw
     # out pairs involving pad atoms with atom index -1 below.
     binz_xyz, biny_xyz, binx_xyz = torch.meshgrid(
-        torch.arange(n_bins_c[2], device=device),
-        torch.arange(n_bins_c[1], device=device),
-        torch.arange(n_bins_c[0], device=device),
+        torch.arange(n_bins_c[2], device=device),  # type: ignore[call-overload]
+        torch.arange(n_bins_c[1], device=device),  # type: ignore[call-overload]
+        torch.arange(n_bins_c[0], device=device),  # type: ignore[call-overload]
         indexing="ij",
     )
     # The memory layout of binx_xyz, biny_xyz, binz_xyz is such that computing
@@ -363,10 +363,10 @@ def primitive_neighbor_list(  # noqa: C901, PLR0915
         cell_shift_vector_n = cell_shift_vector_n[m]
 
     # Sort neighbor list.
-    bin_cnt = torch.argsort(first_at_neigh_tuple_n)
-    first_at_neigh_tuple_n = first_at_neigh_tuple_n[bin_cnt]
-    second_at_neigh_tuple_n = second_at_neigh_tuple_n[bin_cnt]
-    cell_shift_vector_n = cell_shift_vector_n[bin_cnt]
+    bin_cnt_sort_idx = torch.argsort(first_at_neigh_tuple_n)
+    first_at_neigh_tuple_n = first_at_neigh_tuple_n[bin_cnt_sort_idx]
+    second_at_neigh_tuple_n = second_at_neigh_tuple_n[bin_cnt_sort_idx]
+    cell_shift_vector_n = cell_shift_vector_n[bin_cnt_sort_idx]
 
     # Compute distance vectors.
     # TODO: Use .T?
@@ -640,7 +640,7 @@ def vesin_nl(
 def strict_nl(
     cutoff: float,
     positions: torch.Tensor,
-    cell: torch.Tensor,
+    cell: torch.Tensor | None,
     mapping: torch.Tensor,
     system_mapping: torch.Tensor,
     shifts_idx: torch.Tensor,
@@ -658,8 +658,8 @@ def strict_nl(
             is used to filter the neighbor pairs based on their distances.
         positions (torch.Tensor): A tensor of shape (n_atoms, 3) representing
            the positions of the atoms.
-        cell (torch.Tensor): Unit cell vectors according to the row vector convention,
-            i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
+        cell (torch.Tensor | None): Unit cell vectors according to the row vector
+            convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
         mapping (torch.Tensor): A tensor of shape (2, n_pairs) that specifies
             pairs of indices in `positions` for which to compute distances.
@@ -689,10 +689,12 @@ def strict_nl(
     References:
         - https://github.com/felixmusil/torch_nl
     """
-    cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping)
-    if cell_shifts is None:
+    if cell is None:
         d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1)
     else:
+        cell_shifts = transforms.compute_cell_shifts_strict(
+            cell, shifts_idx, system_mapping
+        )
         d2 = (
             (positions[mapping[0]] - positions[mapping[1]] - cell_shifts)
             .square()
diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py
index 29b6fa9ab..65947e8db 100644
--- a/torch_sim/transforms.py
+++ b/torch_sim/transforms.py
@@ -5,7 +5,7 @@ general PBC wrapping.
 """
 
-from collections.abc import Callable, Iterable
+from collections.abc import Callable
 from functools import wraps
 
 import torch
 
@@ -349,7 +349,7 @@ def wrap_positions(
     cell: torch.Tensor,
     *,
     pbc: bool | list[bool] | torch.Tensor = True,
-    center: tuple[float, float, float] = (0.5, 0.5, 0.5),
+    center: tuple[float, float, float] | float = (0.5, 0.5, 0.5),
     pretty_translation: bool = False,
     eps: float = 1e-7,
 ) -> torch.Tensor:
@@ -374,9 +374,10 @@ def wrap_positions(
     device = positions.device
 
     # Convert center to tensor
-    if not hasattr(center, "__len__"):
-        center = (center,) * 3
-    center = torch.tensor(center, dtype=positions.dtype, device=device)
+    if isinstance(center, float):
+        center_pos = torch.tensor((center,) * 3, dtype=positions.dtype, device=device)
+    else:
+        center_pos = torch.tensor(center, dtype=positions.dtype, device=device)
 
     # Handle PBC input
     if isinstance(pbc, bool):
@@ -385,7 +386,7 @@ def wrap_positions(
     pbc = torch.tensor(pbc, dtype=torch.bool, device=device)
 
     # Calculate shift based on center
-    shift = center - 0.5 - eps
+    shift = center_pos - 0.5 - eps
     shift[~pbc] = 0.0
 
     # Convert positions to fractional coordinates
@@ -393,7 +394,7 @@ def wrap_positions(
 
     if pretty_translation:
         fractional = translate_pretty(fractional, pbc)
-        shift = center - 0.5
+        shift = center_pos - 0.5
         shift[~pbc] = 0.0
         fractional += shift
     else:
@@ -488,7 +489,7 @@ def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: _dtype) -> torch.Tensor
             num_repeats[ii] + 1,
             device=num_repeats.device,
             dtype=dtype,
-        )
+        )  # type: ignore[call-overload]
         _, indices = torch.sort(torch.abs(r1))
         reps.append(r1[indices])
     return torch.cartesian_prod(reps[0], reps[1], reps[2])
@@ -497,7 +498,7 @@ def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: _dtype) -> torch.Tensor
 
 def compute_distances_with_cell_shifts(
     pos: torch.Tensor,
     mapping: torch.Tensor,
-    cell_shifts: torch.Tensor,
+    cell_shifts: torch.Tensor | None,
 ) -> torch.Tensor:
     """Compute distances between pairs of positions, optionally including cell shifts.
@@ -513,7 +514,7 @@ def compute_distances_with_cell_shifts(
         mapping (torch.Tensor): A tensor of shape (2, n_pairs) that specifies
             pairs of indices in `pos` for which to compute distances.
-        cell_shifts (Optional[torch.Tensor]): A tensor of shape (n_pairs, 3)
+        cell_shifts (torch.Tensor | None): A tensor of shape (n_pairs, 3)
             representing the shifts to apply to the distances for periodic
            boundary conditions. If None, no shifts are applied.
 
@@ -535,15 +536,15 @@ def compute_distances_with_cell_shifts(
 
 
 def compute_cell_shifts(
-    cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor
-) -> torch.Tensor:
+    cell: torch.Tensor | None, shifts_idx: torch.Tensor, system_mapping: torch.Tensor
+) -> torch.Tensor | None:
     """Compute the cell shifts based on the provided indices and cell matrix.
 
     This function calculates the shifts to apply to positions based on the specified
     indices and the unit cell matrix. If the cell is None, it returns None.
 
     Args:
-        cell (torch.Tensor): A tensor of shape (n_cells, 3, 3)
+        cell (torch.Tensor | None): A tensor of shape (n_cells, 3, 3)
             representing the unit cell matrices.
         shifts_idx (torch.Tensor): A tensor of shape (n_shifts, 3)
             representing the indices for shifts.
@@ -555,12 +556,17 @@ def compute_cell_shifts(
         the computed cell shifts.
     """
     if cell is None:
-        cell_shifts = None
-    else:
-        cell_shifts = torch.einsum(
-            "jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping]
-        )
-    return cell_shifts
+        return None
+    return compute_cell_shifts_strict(cell, shifts_idx, system_mapping)
+
+
+def compute_cell_shifts_strict(
+    cell: torch.Tensor, shifts_idx: torch.Tensor, system_mapping: torch.Tensor
+) -> torch.Tensor:
+    """Same as compute_cell_shifts, but `cell` must not be None; keeping the
+    cell non-optional lets TorchScript compile the call without complaint.
+    """
+    return torch.einsum("jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping])
 
 
 def get_fully_connected_mapping(
@@ -848,7 +854,7 @@ def linked_cell(  # noqa: PLR0915
         shifts_idx, n_atom, dim=0, output_size=n_atom * n_cell_image
     )
     batch_image = torch.zeros((shifts_idx.shape[0]), dtype=torch.long)
-    cell_shifts = compute_cell_shifts(cell.view(-1, 3, 3), shifts_idx, batch_image)
+    cell_shifts = compute_cell_shifts_strict(cell.view(-1, 3, 3), shifts_idx, batch_image)
 
     i_ids = torch.arange(n_atom, device=device, dtype=torch.long)
     i_ids = i_ids.repeat(n_cell_image)
@@ -1101,7 +1107,7 @@ def cutoff_fn(dr: torch.Tensor, *args, **kwargs) -> torch.Tensor:
 
 def high_precision_sum(
     x: torch.Tensor,
-    dim: int | Iterable[int] | None = None,
+    dim: int | list[int] | tuple[int, ...] | None = None,
     *,
     keepdim: bool = False,
 ) -> torch.Tensor:
@@ -1138,7 +1144,7 @@ def high_precision_sum(
 
 def safe_mask(
     mask: torch.Tensor,
-    fn: torch.jit.ScriptFunction,
+    fn: Callable[..., torch.Tensor],
     operand: torch.Tensor,
     placeholder: float = 0.0,
 ) -> torch.Tensor:
@@ -1150,7 +1156,7 @@ def safe_mask(
     Args:
         mask: Boolean tensor indicating which elements to process (True) or
            mask (False)
-        fn: TorchScript function to apply to the masked elements
+        fn: Callable to apply to the masked elements
         operand: Input tensor to apply the function to
         placeholder: Value to use for masked-out positions (default: 0.0)
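
A note on the ImportError tests in tests/test_io.py above: setting an entry in sys.modules to None is documented CPython behavior that makes the next import of that name raise ImportError, which is what lets these tests simulate a missing optional dependency without uninstalling it. A minimal self-contained sketch of the pattern (a hypothetical test, not part of the patch):

    import sys

    import pytest


    def test_import_blocked(monkeypatch: pytest.MonkeyPatch) -> None:
        # With sys.modules["ase"] set to None, "import ase" raises ImportError
        # ("import of ase halted; None in sys.modules").
        monkeypatch.setitem(sys.modules, "ase", None)
        with pytest.raises(ImportError):
            import ase  # noqa: F401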
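
On the compute_cell_shifts / compute_cell_shifts_strict split in torch_sim/transforms.py: the wrapper preserves the Optional cell handling for callers, while the strict variant gives TorchScript a non-optional signature. A standalone sketch of the einsum it wraps, with toy inputs that are illustrative only (not part of the patch):

    import torch


    def compute_cell_shifts_strict(cell, shifts_idx, system_mapping):
        # Row-vector cells: shift[j] = shifts_idx[j] @ cell[system_mapping[j]]
        return torch.einsum("jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[system_mapping])


    cell = 2.0 * torch.eye(3).unsqueeze(0)  # one cubic cell, a = 2.0, shape (1, 3, 3)
    shifts_idx = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]])  # +a1 and -a3 images
    system_mapping = torch.zeros(2, dtype=torch.long)  # both pairs belong to system 0
    print(compute_cell_shifts_strict(cell, shifts_idx, system_mapping))
    # tensor([[ 2.,  0.,  0.],
    #         [ 0.,  0., -2.]])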
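
On widening safe_mask's fn parameter from torch.jit.ScriptFunction to Callable[..., torch.Tensor]: the semantics the docstring describes look roughly like the following sketch, an assumed reimplementation for illustration rather than the module's actual code:

    import torch


    def safe_mask(mask, fn, operand, placeholder=0.0):
        # Evaluate fn only on masked-in values; zeros stand in elsewhere so fn
        # never sees inputs that could produce NaN/inf values or gradients.
        masked = torch.where(mask, operand, torch.zeros_like(operand))
        return torch.where(mask, fn(masked), torch.full_like(operand, placeholder))


    x = torch.tensor([4.0, -1.0, 9.0])
    print(safe_mask(x > 0, torch.sqrt, x))  # tensor([2., 0., 3.])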