diff --git a/deepmd/loss/tensor.py b/deepmd/loss/tensor.py
index 1a97881c3a..91601d40e0 100644
--- a/deepmd/loss/tensor.py
+++ b/deepmd/loss/tensor.py
@@ -78,6 +78,9 @@ def build (self,
         polar_hat = label_dict[self.label_name]
         polar = model_dict[self.tensor_name]
 
+        # YWolfeee: get the 2 norm of label, i.e. polar_hat
+        normalized_term = tf.sqrt(tf.reduce_sum(tf.square(polar_hat)))
+
         # YHT: added for global / local dipole combination
         l2_loss = global_cvt_2_tf_float(0.0)
         more_loss = {
@@ -117,7 +120,7 @@ def build (self,
                 self.l2_loss_global_summary = tf.summary.scalar('l2_global_loss',
                                             tf.sqrt(more_loss['global_loss']) / global_cvt_2_tf_float(atoms))
 
-            # YHT: should only consider atoms with dipole, i.e. atoms
+            # YWolfeee: should only consider atoms with dipole, i.e. atoms
             # atom_norm = 1./ global_cvt_2_tf_float(natoms[0])
             atom_norm = 1./ global_cvt_2_tf_float(atoms)
             global_loss *= atom_norm
@@ -128,7 +131,12 @@ def build (self,
 
         self.l2_l = l2_loss
         self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss))
-        return l2_loss, more_loss
+
+        # YWolfeee: loss normalization, do not influence the printed loss,
+        # just change the training process
+        #return l2_loss, more_loss
+        return l2_loss / normalized_term, more_loss
+
 
     def eval(self, sess, feed_dict, natoms):
         atoms = 0
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 1e151b7a1e..a29fd3b7f2 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -393,19 +393,40 @@ def loss_ener():
         Argument("relative_f", [float,None], optional = True, doc = doc_relative_f)
     ]
 
+# YWolfeee: Modified to support tensor type of loss args.
+def loss_tensor(default_mode):
+    if default_mode == "local":
+        doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If not provided, training will be atomic mode, i.e. atomic label should be provided."
+        doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If it's not provided and global weight is provided, training will be global mode, i.e. global label should be provided. If both global and atomic weight are not provided, training will be atomic mode, i.e. atomic label should be provided."
+        return [
+            Argument("pref_weight", [float,int], optional = True, default = None, doc = doc_global_weight),
+            Argument("pref_atomic_weight", [float,int], optional = True, default = None, doc = doc_local_weight),
+        ]
+    else:
+        doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If not provided, training will be global mode, i.e. global label should be provided."
+        doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If it's not provided and atomic weight is provided, training will be atomic mode, i.e. atomic label should be provided. If both global and atomic weight are not provided, training will be global mode, i.e. global label should be provided."
+        return [
+            Argument("pref_weight", [float,int], optional = True, default = None, doc = doc_global_weight),
+            Argument("pref_atomic_weight", [float,int], optional = True, default = None, doc = doc_local_weight),
+        ]
+
 def loss_variant_type_args():
-    doc_loss = 'The type of the loss. \n\.'
+    doc_loss = 'The type of the loss. The loss type should be set to the fitting type or left unset.\n\.'
+
     return Variant("type", 
-                   [Argument("ener", dict, loss_ener())],
+                   [Argument("ener", dict, loss_ener()),
+                    Argument("dipole", dict, loss_tensor("local")),
+                    Argument("polar", dict, loss_tensor("local")),
+                    Argument("global_polar", dict, loss_tensor("global"))
+                   ],
                    optional = True,
                    default_tag = 'ener',
                    doc = doc_loss)
 
 def loss_args():
-    doc_loss = 'The definition of loss function. The type of the loss depends on the type of the fitting. For fitting type `ener`, the prefactors before energy, force, virial and atomic energy losses may be provided. For fitting type `dipole`, `polar` and `global_polar`, the loss may be an empty `dict` or unset.'
+    doc_loss = 'The definition of loss function. The loss type should be set to the fitting type or left unset.\n\.'
     ca = Argument('loss', dict, [], [loss_variant_type_args()],
                   optional = True,