From 1e1c7a18d35a54bb0c05bb8ed5a6c20aada99ea8 Mon Sep 17 00:00:00 2001
From: Conglong Li
Date: Fri, 22 Oct 2021 11:57:58 -0700
Subject: [PATCH] fix breaking api

---
 Megatron-LM-v1.1.5-ZeRO3/megatron/training.py | 2 +-
 Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
index 8fc8791ee..5cb17d1a0 100644
--- a/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
+++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py
@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator,
     #see_memory_usage(f'before forward {model.global_steps}', force=True)
     # Forward model for one step.
     timers('forward').start()
-    loss, loss_reduced = forward_step_func(data_iterator, model, args.curriculum_learning)
+    loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()
 
     #see_memory_usage(f'before backward {model.global_steps}', force=True)
diff --git a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
index 026702d95..86aac0c9a 100644
--- a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
+++ b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
@@ -86,7 +86,7 @@ def get_batch(data_iterator):
     return tokens, labels, loss_mask, attention_mask, position_ids
 
 
-def forward_step(data_iterator, model, curriculum_learning=False):
+def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
     timers = get_timers()
@@ -98,7 +98,7 @@ def forward_step(data_iterator, model, curriculum_learning=False):
     timers('batch generator').stop()
     # Forward model.
     losses = model(tokens, position_ids, attention_mask, labels=labels)
-    if curriculum_learning and args.curriculum_seqlen < args.seq_length:
+    if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
         loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
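
The patch above moves the curriculum-learning switch out of the forward_step signature and onto the global argument object, so train_step() can call every forward_step_func with the same two-argument signature. A minimal, self-contained sketch of that calling convention follows; get_args(), args.curriculum_learning, args.curriculum_seqlen, and args.seq_length come from the patch itself, while the Namespace stub, the toy model, and the one-batch iterator are illustrative assumptions, not Megatron code.

from argparse import Namespace
import torch

# Stand-in for Megatron's global args store; in Megatron these fields are
# populated by the real argument parser (assumption for illustration).
_GLOBAL_ARGS = Namespace(curriculum_learning=True,
                         curriculum_seqlen=4,
                         seq_length=8)

def get_args():
    # Stand-in for megatron.get_args(): the flag now lives on the global
    # args object instead of being threaded through forward_step().
    return _GLOBAL_ARGS

def forward_step(data_iterator, model):
    """Post-patch signature: no curriculum_learning parameter."""
    args = get_args()
    tokens, loss_mask = next(data_iterator)
    losses = model(tokens)  # per-token losses, shape [batch, curriculum_seqlen]
    if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
        # Curriculum learning shortens the effective sequence, so the loss
        # mask is truncated to match, as in the patched pretrain_gpt2.py.
        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss, {'lm loss': loss}

# Toy call site, mirroring train_step()'s two-argument call:
model = lambda tokens: torch.ones(tokens.shape[0], get_args().curriculum_seqlen)
batch = iter([(torch.zeros(2, 8, dtype=torch.long), torch.ones(2, 8))])
loss, loss_reduced = forward_step(batch, model)

Keeping the flag on args is what lets train_step() in training.py drop the third positional argument without knowing which pretraining script supplied forward_step_func.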