Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import (
List,
Optional,
Expand Down Expand Up @@ -67,6 +68,8 @@ class DescrptSeAtten(DescrptSeA):
exclude_types : List[List[int]]
The excluded pairs of types which have no interaction with each other.
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
set_davg_zero
Set the shift of embedding net input to zero.
activation_function
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
Expand Down Expand Up @@ -97,6 +100,7 @@ def __init__(
trainable: bool = True,
seed: Optional[int] = None,
type_one_side: bool = True,
set_davg_zero: bool = True,
exclude_types: List[List[int]] = [],
activation_function: str = "tanh",
precision: str = "default",
Expand All @@ -107,6 +111,11 @@ def __init__(
attn_mask: bool = False,
multi_task: bool = False,
) -> None:
if not set_davg_zero:
warnings.warn(
"Set 'set_davg_zero' False in descriptor 'se_atten' "
"may cause unexpected incontinuity during model inference!"
)
DescrptSeA.__init__(
self,
rcut,
Expand All @@ -119,7 +128,7 @@ def __init__(
seed=seed,
type_one_side=type_one_side,
exclude_types=exclude_types,
set_davg_zero=True,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
uniform_seed=uniform_seed,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def descrpt_se_atten_args():
doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
doc_trainable = "If the parameters in the embedding net is trainable"
doc_seed = "Random seed for parameter initialization"
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used"
doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1."
doc_attn = "The length of hidden vectors in attention layers"
doc_attn_layer = "The number of attention layers"
Expand Down Expand Up @@ -361,6 +362,9 @@ def descrpt_se_atten_args():
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
),
Argument(
"set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero
),
Argument("attn", int, optional=True, default=128, doc=doc_attn),
Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer),
Argument("attn_dotr", bool, optional=True, default=True, doc=doc_attn_dotr),
Expand Down