diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
index bd071cf3f05e..75cea0bca417 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
@@ -234,6 +234,10 @@ def create_masked_lm_predictions(
         return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
 
     num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
+    if num_to_predict < 1:
+        logging.warning(
+            f'Number of tokens is {len(tokens)} and masked_lm_prob is {masked_lm_prob}. None of the tokens will be masked.'
+        )
 
     ngrams = np.arange(1, max_ngram_size + 1, dtype=np.int64)
     if not geometric_dist:
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
index 0a850289301f..20da2a38f7ce 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -358,7 +358,13 @@ def loss_func(self, loss_mask, sentence_order, output_tensor):
 
         lm_loss_ = lm_loss_.float()
         loss_mask = loss_mask.float()
-        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+        # Sometimes, when the number of tokens is very small, none of the tokens get masked for prediction and the loss mask is all zeros,
+        # i.e. the loss of the entire batch is masked out (in practice this happens when MBS is 1 or 2 and each sequence has fewer than 7 tokens).
+        if loss_mask.sum() == 0:
+            lm_loss = torch.sum(lm_loss_.view(-1)) * 0.0
+        else:
+            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
 
         if sop_logits is not None:
             sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
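
For reference, here is a minimal standalone sketch (not part of the patch) of the failure mode the second hunk guards against. The tensor shapes and values are hypothetical; only the masked reduction mirrors the patched code.

import torch

# Hypothetical per-token LM losses for a tiny batch (MBS=2, 5 tokens each)
# where no token was selected for masking, so the loss mask is all zeros.
lm_loss_ = torch.rand(2, 5)
loss_mask = torch.zeros(2, 5)

# Original reduction: dividing by loss_mask.sum() == 0 yields NaN,
# which then propagates through the backward pass.
naive = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(naive)  # tensor(nan)

# Guarded reduction, same shape as the fix: return an exact zero that
# still depends on lm_loss_, keeping it in the autograd graph.
if loss_mask.sum() == 0:
    lm_loss = torch.sum(lm_loss_.view(-1)) * 0.0
else:
    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)  # tensor(0.)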