diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
index 8fc8791ee..5cb17d1a0 100644
--- a/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
+++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator,
     #see_memory_usage(f'before forward {model.global_steps}', force=True)
     # Forward model for one step.
     timers('forward').start()
-    loss, loss_reduced = forward_step_func(data_iterator, model, args.curriculum_learning)
+    loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()
 
     #see_memory_usage(f'before backward {model.global_steps}', force=True)
diff --git a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
index 026702d95..86aac0c9a 100644
--- a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
+++ b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
@@ -86,7 +86,7 @@ def get_batch(data_iterator):
     return tokens, labels, loss_mask, attention_mask, position_ids
 
 
-def forward_step(data_iterator, model, curriculum_learning=False):
+def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
     timers = get_timers()
@@ -98,7 +98,7 @@ def forward_step(data_iterator, model, curriculum_learning=False):
     timers('batch generator').stop()
     # Forward model.
     losses = model(tokens, position_ids, attention_mask, labels=labels)
-    if curriculum_learning and args.curriculum_seqlen < args.seq_length:
+    if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
         loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
         loss_mask = loss_mask.view(-1)
         loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
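
For context, a minimal standalone sketch of what the rewritten branch in forward_step computes: when curriculum learning runs the model on only the first curriculum_seqlen tokens, the per-token losses cover just that prefix, so the loss mask built for the full seq_length is trimmed to the same prefix before averaging. The helper name curriculum_masked_loss and the tensor shapes below are illustrative assumptions, not part of the patch.

import torch

def curriculum_masked_loss(losses, loss_mask, curriculum_seqlen, seq_length):
    # Hypothetical helper (not part of the patch): mirrors the branch in
    # forward_step. When curriculum learning feeds the model only the first
    # `curriculum_seqlen` tokens, the per-token losses span just that prefix,
    # so the mask built for the full `seq_length` is trimmed to match before
    # computing the masked average.
    if curriculum_seqlen < seq_length:
        loss_mask = loss_mask[:, :curriculum_seqlen].contiguous()
    loss_mask = loss_mask.view(-1)
    return torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Illustrative shapes only: batch of 2, full seq_length of 8, curriculum_seqlen of 4.
losses = torch.rand(2, 4)      # per-token losses for the truncated prefix
loss_mask = torch.ones(2, 8)   # mask built for the full sequence length
print(curriculum_masked_loss(losses, loss_mask, curriculum_seqlen=4, seq_length=8))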