-
Notifications
You must be signed in to change notification settings - Fork 599
feat(jax): neighbor stat #4258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
feat(jax): neighbor stat #4258
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
|
||
| import jaxlib | ||
|
|
||
| from deepmd.jax.env import ( | ||
| jax, | ||
| ) | ||
| from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase | ||
|
|
||
|
|
||
| class AutoBatchSize(AutoBatchSizeBase): | ||
| """Auto batch size. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| initial_batch_size : int, default: 1024 | ||
| initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE | ||
| is not set | ||
| factor : float, default: 2. | ||
| increased factor | ||
|
|
||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| initial_batch_size: int = 1024, | ||
| factor: float = 2.0, | ||
| ): | ||
| super().__init__( | ||
| initial_batch_size=initial_batch_size, | ||
| factor=factor, | ||
| ) | ||
|
|
||
| def is_gpu_available(self) -> bool: | ||
| """Check if GPU is available. | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| True if GPU is available | ||
| """ | ||
| return jax.devices()[0].platform == "gpu" | ||
njzjz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def is_oom_error(self, e: Exception) -> bool: | ||
| """Check if the exception is an OOM error. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| e : Exception | ||
| Exception | ||
| """ | ||
| # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error, | ||
| # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924 | ||
| # (the meaningless error message should be considered as a bug in cusolver) | ||
| if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and ( | ||
| "RESOURCE_EXHAUSTED:" in e.args[0] | ||
| ): | ||
| return True | ||
| return False | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| from collections.abc import ( | ||
| Iterator, | ||
| ) | ||
| from typing import ( | ||
| Optional, | ||
| ) | ||
|
|
||
| import numpy as np | ||
|
|
||
| from deepmd.dpmodel.common import ( | ||
| to_numpy_array, | ||
| ) | ||
| from deepmd.dpmodel.utils.neighbor_stat import ( | ||
| NeighborStatOP, | ||
| ) | ||
| from deepmd.jax.common import ( | ||
| to_jax_array, | ||
| ) | ||
| from deepmd.jax.utils.auto_batch_size import ( | ||
| AutoBatchSize, | ||
| ) | ||
| from deepmd.utils.data_system import ( | ||
| DeepmdDataSystem, | ||
| ) | ||
| from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat | ||
|
|
||
|
|
||
| class NeighborStat(BaseNeighborStat): | ||
| """Neighbor statistics using JAX. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| ntypes : int | ||
| The num of atom types | ||
| rcut : float | ||
| The cut-off radius | ||
| mixed_type : bool, optional, default=False | ||
| Treat all types as a single type. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| ntypes: int, | ||
| rcut: float, | ||
| mixed_type: bool = False, | ||
| ) -> None: | ||
| super().__init__(ntypes, rcut, mixed_type) | ||
| self.op = NeighborStatOP(ntypes, rcut, mixed_type) | ||
| self.auto_batch_size = AutoBatchSize() | ||
|
|
||
| def iterator( | ||
| self, data: DeepmdDataSystem | ||
| ) -> Iterator[tuple[np.ndarray, float, str]]: | ||
| """Iterator method for producing neighbor statistics data. | ||
|
|
||
| Yields | ||
| ------ | ||
| np.ndarray | ||
| The maximal number of neighbors | ||
| float | ||
| The squared minimal distance between two atoms | ||
| str | ||
| The directory of the data system | ||
| """ | ||
| for ii in range(len(data.system_dirs)): | ||
| for jj in data.data_systems[ii].dirs: | ||
| data_set = data.data_systems[ii] | ||
| data_set_data = data_set._load_set(jj) | ||
njzjz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| minrr2, max_nnei = self.auto_batch_size.execute_all( | ||
njzjz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self._execute, | ||
| data_set_data["coord"].shape[0], | ||
| data_set.get_natoms(), | ||
| data_set_data["coord"], | ||
| data_set_data["type"], | ||
| data_set_data["box"] if data_set.pbc else None, | ||
| ) | ||
| yield np.max(max_nnei, axis=0), np.min(minrr2), jj | ||
|
|
||
| def _execute( | ||
| self, | ||
| coord: np.ndarray, | ||
| atype: np.ndarray, | ||
| cell: Optional[np.ndarray], | ||
| ): | ||
| """Execute the operation. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| coord | ||
| The coordinates of atoms. | ||
| atype | ||
| The atom types. | ||
| cell | ||
| The cell. | ||
| """ | ||
| minrr2, max_nnei = self.op( | ||
| to_jax_array(coord), | ||
| to_jax_array(atype), | ||
| to_jax_array(cell), | ||
| ) | ||
njzjz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| minrr2 = to_numpy_array(minrr2) | ||
| max_nnei = to_numpy_array(max_nnei) | ||
| return minrr2, max_nnei | ||
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.