From 8f83c2999f69f23ad4e5d2c7c4aa2abb98cd84e8 Mon Sep 17 00:00:00 2001 From: YWolfeee <1800017704@pku.edu.cn> Date: Tue, 20 Apr 2021 05:48:09 +0000 Subject: [PATCH] YWolfeee: modified to support loss normalization --- deepmd/loss/tensor.py | 12 ++++++++++-- deepmd/utils/argcheck.py | 26 +++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/deepmd/loss/tensor.py b/deepmd/loss/tensor.py index addccdcadf..256b98771d 100644 --- a/deepmd/loss/tensor.py +++ b/deepmd/loss/tensor.py @@ -78,6 +78,9 @@ def build (self, polar_hat = label_dict[self.label_name] polar = model_dict[self.tensor_name] + # YWolfeee: get the 2 norm of label, i.e. polar_hat + normalized_term = tf.sqrt(tf.reduce_sum(tf.square(polar_hat))) + # YHT: added for global / local dipole combination l2_loss = global_cvt_2_tf_float(0.0) more_loss = { @@ -117,7 +120,7 @@ def build (self, self.l2_loss_global_summary = tf.summary.scalar('l2_global_loss', tf.sqrt(more_loss['global_loss']) / global_cvt_2_tf_float(atoms)) - # YHT: should only consider atoms with dipole, i.e. atoms + # YWolfeee: should only consider atoms with dipole, i.e. atoms # atom_norm = 1./ global_cvt_2_tf_float(natoms[0]) atom_norm = 1./ global_cvt_2_tf_float(atoms) global_loss *= atom_norm @@ -128,7 +131,12 @@ def build (self, self.l2_l = l2_loss self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss)) - return l2_loss, more_loss + + # YWolfeee: loss normalization, do not influence the printed loss, + # just change the training process + #return l2_loss, more_loss + return l2_loss / normalized_term, more_loss + def print_header(self): prop_fmt = ' %11s %11s' diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 310090adfd..39dd6e53e9 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -357,18 +357,38 @@ def loss_ener(): Argument("relative_f", [float,None], optional = True, doc = doc_relative_f) ] +# YWolfeee: Modified to support tensor type of loss args. 
+def loss_tensor(default_mode): + if default_mode == "local": + doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If not provided, training will be atomic mode, i.e. atomic label should be provided." + doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If it's not provided and global weight is provided, training will be global mode, i.e. global label should be provided. If both global and atomic weight are not provided, training will be atomic mode, i.e. atomic label should be provided." + return [ + Argument("pref_weight", [float,int], optional = True, default = None, doc = doc_global_weight), + Argument("pref_atomic_weight", [float,int], optional = True, default = None, doc = doc_local_weight), + ] + else: + doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If not provided, training will be global mode, i.e. global label should be provided." + doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If it's not provided and atomic weight is provided, training will be atomic mode, i.e. atomic label should be provided. If both global and atomic weight are not provided, training will be global mode, i.e. global label should be provided." + return [ + Argument("pref_weight", [float,int], optional = True, default = None, doc = doc_global_weight), + Argument("pref_atomic_weight", [float,int], optional = True, default = None, doc = doc_local_weight), + ] def loss_variant_type_args(): - doc_loss = 'The type of the loss. For fitting type `ener`, the loss type should be set to `ener` or left unset. For tensorial fitting types `dipole`, `polar` and `global_polar`, the type should be left unset.\n\.' + doc_loss = 'The type of the loss. The loss type should be set to the fitting type or left unset.\n\.' 
return Variant("type", - [Argument("ener", dict, loss_ener())], + [Argument("ener", dict, loss_ener()), + Argument("dipole", dict, loss_tensor("local")), + Argument("polar", dict, loss_tensor("local")), + Argument("global_polar", dict, loss_tensor("global")) + ], optional = True, default_tag = 'ener', doc = doc_loss) def loss_args(): - doc_loss = 'The definition of loss function. The type of the loss depends on the type of the fitting. For fitting type `ener`, the prefactors before energy, force, virial and atomic energy losses may be provided. For fitting type `dipole`, `polar` and `global_polar`, the loss may be an empty `dict` or unset.' + doc_loss = 'The definition of loss function. The loss type should be set to the fitting type or left unset.' ca = Argument('loss', dict, [], [loss_variant_type_args()], optional = True,