diff --git a/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py b/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py index dcf461d5334f..9608a0320bd6 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py @@ -539,6 +539,7 @@ def generate_candidates(self, labels, template_length, input_ids, attn_masks): for i in generated_tokens ] generated_tokens = torch.cat(generated_tokens, axis=0) + num_prompt_tokens = 0 elif self.cfg.library == "megatron":