diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 061cab1dc8..41cb3ca333 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -689,12 +689,14 @@ def build( outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]]) # add bias self.atom_ener_before = outs * atype_filter - self.add_type = tf.reshape( + # atomic bias energy from data statistics + self.atom_bias_ener = tf.reshape( tf.nn.embedding_lookup(self.t_bias_atom_e, self.atype_nloc), [tf.shape(inputs)[0], tf.reduce_sum(natoms[2 : 2 + ntypes_atom])], ) - outs = outs + self.add_type + outs = outs + self.atom_bias_ener outs *= atype_filter + self.atom_bias_ener *= atype_filter self.atom_ener_after = outs if self.tot_ener_zero: diff --git a/deepmd/model/ener.py b/deepmd/model/ener.py index 887343e75a..979599b14c 100644 --- a/deepmd/model/ener.py +++ b/deepmd/model/ener.py @@ -58,6 +58,8 @@ class EnerModel(StandardModel): The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided. sw_rmin The upper boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided. + srtab_add_bias : bool + Whether add energy bias from the statistics of the data to short-range tabulated atomic energy. It only takes effect when `use_srtab` is provided. spin spin data_stat_nsample @@ -78,6 +80,7 @@ def __init__( smin_alpha: Optional[float] = None, sw_rmin: Optional[float] = None, sw_rmax: Optional[float] = None, + srtab_add_bias: bool = True, spin: Optional[Spin] = None, data_bias_nsample: int = 10, **kwargs, @@ -96,6 +99,7 @@ def __init__( sw_rmin=sw_rmin, sw_rmax=sw_rmax, spin=spin, + srtab_add_bias=srtab_add_bias, **kwargs, ) self.numb_fparam = self.fitting.get_numb_fparam() @@ -263,6 +267,8 @@ def build( sel_a=sel_a, sel_r=sel_r, ) + if self.srtab_add_bias: + tab_atom_ener += self.fitting.atom_bias_ener energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]]) tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape( tab_atom_ener, [-1] diff --git a/deepmd/model/model.py b/deepmd/model/model.py index c4baf5df54..4de06dc42f 100644 --- a/deepmd/model/model.py +++ b/deepmd/model/model.py @@ -69,6 +69,8 @@ class Model(ABC): The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided. sw_rmin The upper boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided. + srtab_add_bias : bool + Whether add energy bias from the statistics of the data to short-range tabulated atomic energy. It only takes effect when `use_srtab` is provided. spin spin compress @@ -104,6 +106,7 @@ def __init__( smin_alpha: Optional[float] = None, sw_rmin: Optional[float] = None, sw_rmax: Optional[float] = None, + srtab_add_bias: bool = True, spin: Optional[Spin] = None, compress: Optional[dict] = None, **kwargs, @@ -131,6 +134,7 @@ def __init__( self.smin_alpha = smin_alpha self.sw_rmin = sw_rmin self.sw_rmax = sw_rmax + self.srtab_add_bias = srtab_add_bias else: self.srtab = None diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 392e0f8907..28ac5747b8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -758,6 +758,7 @@ def model_args(): doc_smin_alpha = "The short-range tabulated interaction will be swithed according to the distance of the nearest neighbor. This distance is calculated by softmin. This parameter is the decaying parameter in the softmin. It is only required when `use_srtab` is provided." doc_sw_rmin = "The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided." doc_sw_rmax = "The upper boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided." + doc_srtab_add_bias = "Whether add energy bias from the statistics of the data to short-range tabulated atomic energy. It only takes effect when `use_srtab` is provided." doc_compress_config = "Model compression configurations" doc_spin = "The settings for systems with spin." return Argument( @@ -790,6 +791,13 @@ def model_args(): Argument("smin_alpha", float, optional=True, doc=doc_smin_alpha), Argument("sw_rmin", float, optional=True, doc=doc_sw_rmin), Argument("sw_rmax", float, optional=True, doc=doc_sw_rmax), + Argument( + "srtab_add_bias", + bool, + optional=True, + default=True, + doc=doc_srtab_add_bias, + ), Argument( "type_embedding", dict,