diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index 356c1fae2b..ba8fa9adfa 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -1,3 +1,4 @@ +import warnings from typing import ( List, Optional, @@ -67,6 +68,8 @@ class DescrptSeAtten(DescrptSeA): exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. + set_davg_zero + Set the shift of embedding net input to zero. activation_function The activation function in the embedding net. Supported options are |ACTIVATION_FN| precision @@ -97,6 +100,7 @@ def __init__( trainable: bool = True, seed: Optional[int] = None, type_one_side: bool = True, + set_davg_zero: bool = True, exclude_types: List[List[int]] = [], activation_function: str = "tanh", precision: str = "default", @@ -107,6 +111,11 @@ def __init__( attn_mask: bool = False, multi_task: bool = False, ) -> None: + if not set_davg_zero: + warnings.warn( + "Setting 'set_davg_zero' to False in descriptor 'se_atten' " + "may cause unexpected discontinuity during model inference!" + ) DescrptSeA.__init__( self, rcut, @@ -119,7 +128,7 @@ def __init__( seed=seed, type_one_side=type_one_side, exclude_types=exclude_types, - set_davg_zero=True, + set_davg_zero=set_davg_zero, activation_function=activation_function, precision=precision, uniform_seed=uniform_seed, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 93135304a9..ca9afd43ac 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -325,6 +325,7 @@ def descrpt_se_atten_args(): doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." 
doc_trainable = "If the parameters in the embedding net is trainable" doc_seed = "Random seed for parameter initialization" + doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." doc_attn = "The length of hidden vectors in attention layers" doc_attn_layer = "The number of attention layers" @@ -361,6 +362,9 @@ def descrpt_se_atten_args(): Argument( "exclude_types", list, optional=True, default=[], doc=doc_exclude_types ), + Argument( + "set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero + ), Argument("attn", int, optional=True, default=128, doc=doc_attn), Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer), Argument("attn_dotr", bool, optional=True, default=True, doc=doc_attn_dotr),