From e03a38a934df747ed2d2b9e717c46a296a24cd33 Mon Sep 17 00:00:00 2001
From: learngit
Date: Fri, 23 Jul 2021 09:24:32 +0800
Subject: [PATCH] Update losses.py

avoid nan loss in SupCon
---
 losses.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/losses.py b/losses.py
index 17117d42..911fd51b 100644
--- a/losses.py
+++ b/losses.py
@@ -89,7 +89,13 @@ def forward(self, features, labels=None, mask=None):
         log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
 
         # compute mean of log-likelihood over positive
-        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
+        # avoid NaN loss when a class has only one sample in the batch (e.g., labels 0,1,...,1 in binary classification yield a zero positive count for the first sample),
+        # which would otherwise make the whole batch loss NaN; such rows contribute zero instead of NaN
+        pos_per_sample=mask.sum(1) #B
+        pos_per_sample[pos_per_sample<1e-6]=1.0
+        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_sample #mask.sum(1)
+
+        #mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
 
         # loss
         loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos