From dbc40ef6d29d787c7e9fc8155e470cbb4423b1e5 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 15 Mar 2023 17:31:47 +0800 Subject: [PATCH 1/2] Fix incompatibility after fixing the incontinuity of se_atten Fix incompatibility after fixing the incontinuity of se_atten --- deepmd/descriptor/se_atten.py | 9 ++++++++- deepmd/utils/argcheck.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index 356c1fae2b..0c7e213038 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -5,6 +5,7 @@ ) import numpy as np +import warnings from packaging.version import ( Version, ) @@ -67,6 +68,8 @@ class DescrptSeAtten(DescrptSeA): exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. + set_davg_zero + Set the shift of embedding net input to zero. activation_function The activation function in the embedding net. Supported options are |ACTIVATION_FN| precision @@ -97,6 +100,7 @@ def __init__( trainable: bool = True, seed: Optional[int] = None, type_one_side: bool = True, + set_davg_zero: bool = True, exclude_types: List[List[int]] = [], activation_function: str = "tanh", precision: str = "default", @@ -107,6 +111,9 @@ def __init__( attn_mask: bool = False, multi_task: bool = False, ) -> None: + if not set_davg_zero: + warnings.warn("Set 'set_davg_zero' False in descriptor 'se_atten' " + "may cause unexpected incontinuity during model inference!") DescrptSeA.__init__( self, rcut, @@ -119,7 +126,7 @@ def __init__( seed=seed, type_one_side=type_one_side, exclude_types=exclude_types, - set_davg_zero=True, + set_davg_zero=set_davg_zero, activation_function=activation_function, precision=precision, uniform_seed=uniform_seed, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 93135304a9..ca9afd43ac 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -325,6 +325,7 @@ def descrpt_se_atten_args(): doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." doc_trainable = "If the parameters in the embedding net is trainable" doc_seed = "Random seed for parameter initialization" + doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." doc_attn = "The length of hidden vectors in attention layers" doc_attn_layer = "The number of attention layers" @@ -361,6 +362,9 @@ def descrpt_se_atten_args(): Argument( "exclude_types", list, optional=True, default=[], doc=doc_exclude_types ), + Argument( + "set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero + ), Argument("attn", int, optional=True, default=128, doc=doc_attn), Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer), Argument("attn_dotr", bool, optional=True, default=True, doc=doc_attn_dotr), From 1a116c6cc470791fc14b93ec4797f675a5f464f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Mar 2023 09:34:22 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/descriptor/se_atten.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index 0c7e213038..ba8fa9adfa 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -1,3 +1,4 @@ +import warnings from typing import ( List, Optional, @@ -5,7 +6,6 @@ ) import numpy as np -import warnings from packaging.version import ( Version, ) @@ -112,8 +112,10 @@ def __init__( multi_task: bool = False, ) -> None: if not set_davg_zero: - warnings.warn("Set 'set_davg_zero' False in descriptor 'se_atten' " - "may cause unexpected incontinuity during model inference!") + warnings.warn( + "Set 'set_davg_zero' False in descriptor 'se_atten' " + "may cause unexpected incontinuity during model inference!" + ) DescrptSeA.__init__( self, rcut,