Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class JAXBackend(Backend):
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.IO
# Backend.Feature.ENTRY_POINT
| Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".jax"]
Expand Down Expand Up @@ -82,7 +82,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
type[NeighborStat]
The neighbor statistics of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.neighbor_stat import (
NeighborStat,
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
Expand Down
35 changes: 18 additions & 17 deletions deepmd/dpmodel/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -68,42 +69,42 @@ def call(
np.ndarray
The maximal number of neighbors
"""
xp = array_api_compat.array_namespace(coord, atype)
nframes = coord.shape[0]
coord = coord.reshape(nframes, -1, 3)
coord = xp.reshape(coord, (nframes, -1, 3))
nloc = coord.shape[1]
coord = coord.reshape(nframes, nloc * 3)
coord = xp.reshape(coord, (nframes, nloc * 3))
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
coord, atype, cell, self.rcut
)

coord1 = extend_coord.reshape(nframes, -1)
coord1 = xp.reshape(extend_coord, (nframes, -1))
nall = coord1.shape[1] // 3
coord0 = coord1[:, : nloc * 3]
diff = (
coord1.reshape([nframes, -1, 3])[:, None, :, :]
- coord0.reshape([nframes, -1, 3])[:, :, None, :]
xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
- xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = np.eye(nloc, nall, dtype=bool)
diff[:, mask] = np.inf
rr2 = np.sum(np.square(diff), axis=-1)
min_rr2 = np.min(rr2, axis=-1)
mask = xp.eye(nloc, nall, dtype=xp.bool)
mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3))
diff = xp.where(mask, xp.full_like(diff, xp.inf), diff)
rr2 = xp.sum(xp.square(diff), axis=-1)
min_rr2 = xp.min(rr2, axis=-1)
# count the number of neighbors
if not self.mixed_types:
mask = rr2 < self.rcut**2
nnei = np.zeros((nframes, nloc, self.ntypes), dtype=int)
nneis = []
for ii in range(self.ntypes):
nnei[:, :, ii] = np.sum(
mask & (extend_atype == ii)[:, None, :], axis=-1
)
nneis.append(xp.sum(mask & (extend_atype == ii)[:, None, :], axis=-1))
nnei = xp.stack(nneis, axis=-1)
else:
mask = rr2 < self.rcut**2
# virtual type (<0) are not counted
nnei = np.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1).reshape(
nframes, nloc, 1
)
max_nnei = np.max(nnei, axis=1)
nnei = xp.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1)
nnei = xp.reshape(nnei, (nframes, nloc, 1))
max_nnei = xp.max(nnei, axis=1)
return min_rr2, max_nnei


Expand Down
59 changes: 59 additions & 0 deletions deepmd/jax/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import jaxlib

from deepmd.jax.env import (
jax,
)
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class AutoBatchSize(AutoBatchSizeBase):
"""Auto batch size.

Parameters
----------
initial_batch_size : int, default: 1024
initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
is not set
factor : float, default: 2.
increased factor

"""

def __init__(
self,
initial_batch_size: int = 1024,
factor: float = 2.0,
):
super().__init__(
initial_batch_size=initial_batch_size,
factor=factor,
)

def is_gpu_available(self) -> bool:
"""Check if GPU is available.

Returns
-------
bool
True if GPU is available
"""
return jax.devices()[0].platform == "gpu"

def is_oom_error(self, e: Exception) -> bool:
"""Check if the exception is an OOM error.

Parameters
----------
e : Exception
Exception
"""
# several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error,
# such as https://github.com/JuliaGPU/CUDA.jl/issues/1924
# (the meaningless error message should be considered as a bug in cusolver)
if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and (
"RESOURCE_EXHAUSTED:" in e.args[0]
):
return True
return False
104 changes: 104 additions & 0 deletions deepmd/jax/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Iterator,
)
from typing import (
Optional,
)

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStatOP,
)
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.utils.auto_batch_size import (
AutoBatchSize,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat


class NeighborStat(BaseNeighborStat):
"""Neighbor statistics using JAX.

Parameters
----------
ntypes : int
The num of atom types
rcut : float
The cut-off radius
mixed_type : bool, optional, default=False
Treat all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
mixed_type: bool = False,
) -> None:
super().__init__(ntypes, rcut, mixed_type)
self.op = NeighborStatOP(ntypes, rcut, mixed_type)
self.auto_batch_size = AutoBatchSize()

def iterator(
self, data: DeepmdDataSystem
) -> Iterator[tuple[np.ndarray, float, str]]:
"""Iterator method for producing neighbor statistics data.

Yields
------
np.ndarray
The maximal number of neighbors
float
The squared minimal distance between two atoms
str
The directory of the data system
"""
for ii in range(len(data.system_dirs)):
for jj in data.data_systems[ii].dirs:
data_set = data.data_systems[ii]
data_set_data = data_set._load_set(jj)
minrr2, max_nnei = self.auto_batch_size.execute_all(
self._execute,
data_set_data["coord"].shape[0],
data_set.get_natoms(),
data_set_data["coord"],
data_set_data["type"],
data_set_data["box"] if data_set.pbc else None,
)
yield np.max(max_nnei, axis=0), np.min(minrr2), jj

def _execute(
self,
coord: np.ndarray,
atype: np.ndarray,
cell: Optional[np.ndarray],
):
"""Execute the operation.

Parameters
----------
coord
The coordinates of atoms.
atype
The atom types.
cell
The cell.
"""
minrr2, max_nnei = self.op(
to_jax_array(coord),
to_jax_array(atype),
to_jax_array(cell),
)
minrr2 = to_numpy_array(minrr2)
max_nnei = to_numpy_array(max_nnei)
return minrr2, max_nnei
69 changes: 0 additions & 69 deletions source/tests/common/dpmodel/test_neighbor_stat.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from ..seed import (
GLOBAL_SEED,
)
from .common import (
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
)


def gen_sys(nframes):
Expand Down Expand Up @@ -42,7 +47,7 @@ def setUp(self):
def tearDown(self):
shutil.rmtree("system_0")

def test_neighbor_stat(self):
def run_neighbor_stat(self, backend):
for rcut in (0.0, 1.0, 2.0, 4.0):
for mixed_type in (True, False):
with self.subTest(rcut=rcut, mixed_type=mixed_type):
Expand All @@ -52,7 +57,7 @@ def test_neighbor_stat(self):
rcut=rcut,
type_map=["TYPE", "NO_THIS_TYPE"],
mixed_type=mixed_type,
backend="pytorch",
backend=backend,
)
upper = np.ceil(rcut) + 1
X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper]
Expand All @@ -67,3 +72,18 @@ def test_neighbor_stat(self):
if not mixed_type:
ret.append(0)
np.testing.assert_array_equal(max_nbor_size, ret)

@unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed")
def test_neighbor_stat_tf(self):
self.run_neighbor_stat("tensorflow")

@unittest.skipUnless(INSTALLED_PT, "pytorch is not installed")
def test_neighbor_stat_pt(self):
self.run_neighbor_stat("pytorch")

def test_neighbor_stat_dp(self):
self.run_neighbor_stat("numpy")

@unittest.skipUnless(INSTALLED_JAX, "jax is not installed")
def test_neighbor_stat_jax(self):
self.run_neighbor_stat("jax")
Loading