From 4c4e17a27e1fff30addcecaa70c04371e2dee42e Mon Sep 17 00:00:00 2001
From: Conglong Li
Date: Mon, 25 Oct 2021 12:11:25 -0700
Subject: [PATCH] handle None case

---
 Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py
index be4da2202..0671f393d 100644
--- a/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py
+++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py
@@ -55,8 +55,8 @@ def __init__(self, num_tokentypes=0, parallel_output=True):
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None, curriculum_seqlen=None):
+        args = get_args()
         if curriculum_seqlen is not None:
-            args = get_args()
             args.curriculum_seqlen = curriculum_seqlen
             if curriculum_seqlen < input_ids.size()[1]:
                 # seqlen-based curriculum learning
@@ -67,6 +67,10 @@ def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 # attention_mask has size [1, 1, seqlen, seqlen]
                 attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()
+        else:
+            if args.curriculum_learning:
+                # If got a None input, need to reset curriculum_seqlen on user side
+                args.curriculum_seqlen = args.seq_length
 
         # Language model.
         lm_output = self.language_model(input_ids,