Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions deepmd/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def build (self,
polar_hat = label_dict[self.label_name]
polar = model_dict[self.tensor_name]

# YWolfeee: get the 2 norm of label, i.e. polar_hat
normalized_term = tf.sqrt(tf.reduce_sum(tf.square(polar_hat)))

# YHT: added for global / local dipole combination
l2_loss = global_cvt_2_tf_float(0.0)
more_loss = {
Expand Down Expand Up @@ -117,7 +120,7 @@ def build (self,
self.l2_loss_global_summary = tf.summary.scalar('l2_global_loss',
tf.sqrt(more_loss['global_loss']) / global_cvt_2_tf_float(atoms))

# YHT: should only consider atoms with dipole, i.e. atoms
# YWolfeee: should only consider atoms with dipole, i.e. atoms
# atom_norm = 1./ global_cvt_2_tf_float(natoms[0])
atom_norm = 1./ global_cvt_2_tf_float(atoms)
global_loss *= atom_norm
Expand All @@ -128,7 +131,12 @@ def build (self,
self.l2_l = l2_loss

self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss))
return l2_loss, more_loss

# YWolfeee: loss normalization, do not influence the printed loss,
# just change the training process
#return l2_loss, more_loss
return l2_loss / normalized_term, more_loss


def eval(self, sess, feed_dict, natoms):
atoms = 0
Expand Down
27 changes: 24 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,19 +393,40 @@ def loss_ener():
Argument("relative_f", [float,None], optional = True, doc = doc_relative_f)
]

# YWolfeee: Modified to support tensor type of loss args.
def loss_tensor(default_mode):
    """Build the argument list for a tensor-type loss (dipole / polar).

    Parameters
    ----------
    default_mode : str
        ``"local"`` or ``"global"``.  The two modes return the same
        arguments; only the documentation strings differ, describing which
        training mode is the fallback when a prefactor is not provided.

    Returns
    -------
    list
        The ``pref_weight`` and ``pref_atomic_weight`` ``Argument`` objects.
    """
    if default_mode == "local":
        doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If not provided, training will be atomic mode, i.e. atomic label should be provided."
        doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If it's not provided and global weight is provided, training will be global mode, i.e. global label should be provided. If both global and atomic weight are not provided, training will be atomic mode, i.e. atomic label should be provided."
    else:
        doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If not provided, training will be global mode, i.e. global label should be provided."
        doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If it's not provided and atomic weight is provided, training will be atomic mode, i.e. atomic label should be provided. If both global and atomic weight are not provided, training will be global mode, i.e. global label should be provided."
    # Both modes expose exactly the same two prefactor arguments; the
    # branches above only select the mode-appropriate documentation, so a
    # single return avoids keeping two identical lists in sync.
    return [
        Argument("pref_weight", [float, int], optional=True, default=None, doc=doc_global_weight),
        Argument("pref_atomic_weight", [float, int], optional=True, default=None, doc=doc_local_weight),
    ]

def loss_variant_type_args():
    """Return the ``Variant`` selecting the loss type by its ``type`` key.

    Supported choices are ``ener`` (energy fitting) plus the tensor losses
    ``dipole``, ``polar`` (local-mode defaults) and ``global_polar``
    (global-mode defaults).  Defaults to ``ener`` when unset.
    """
    doc_loss = 'The type of the loss. The loss type should be set to the fitting type or left unset.\n\.'
    # NOTE(review): the source span interleaved the pre-change and
    # post-change diff lines (a stale doc_loss assignment and a stale
    # single-choice list), which would have passed two positional choice
    # lists to Variant; only the post-change content is kept here.
    return Variant("type",
                   [Argument("ener", dict, loss_ener()),
                    Argument("dipole", dict, loss_tensor("local")),
                    Argument("polar", dict, loss_tensor("local")),
                    Argument("global_polar", dict, loss_tensor("global"))
                    ],
                   optional=True,
                   default_tag='ener',
                   doc=doc_loss)


def loss_args():
doc_loss = 'The definition of loss function. The type of the loss depends on the type of the fitting. For fitting type `ener`, the prefactors before energy, force, virial and atomic energy losses may be provided. For fitting type `dipole`, `polar` and `global_polar`, the loss may be an empty `dict` or unset.'
doc_loss = 'The definition of loss function. The loss type should be set to the fitting type or left unset.\n\.'
ca = Argument('loss', dict, [],
[loss_variant_type_args()],
optional = True,
Expand Down