diff --git a/megatron/training.py b/megatron/training.py index cd789412c5d..abb5acaeb6e 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1,3 +1,5 @@ + + # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. @@ -251,14 +253,14 @@ def pretrain(train_valid_test_dataset_provider, config = core_transformer_config_from_args(args) if args.do_valid: - prefix = f'iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from validation set' + prefix = f'iteration {iteration} on {args.eval_iters * args.real_global_batch_size}-sample draw from validation set' evaluate_and_print_results(prefix, forward_step_func, valid_data_iterator, model, iteration, process_non_loss_data_func, config, verbose=True, write_to_tensorboard=not args.skip_train) if args.do_test: - prefix = f'iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from test set' + prefix = f'iteration {iteration} on {args.eval_iters * args.real_global_batch_size}-sample draw from test set' evaluate_and_print_results(prefix, forward_step_func, test_data_iterator, model, iteration, process_non_loss_data_func, config, @@ -274,7 +276,7 @@ def update_train_iters(args): # Constant batch size with sample-based training. if args.rampup_batch_size is None: - args.train_iters = args.train_samples // args.global_batch_size + args.train_iters = args.train_samples // args.real_global_batch_size else: # Sample based training with rampup batch size. @@ -283,14 +285,14 @@ def update_train_iters(args): # Rampup phase. while consumed_samples <= int(args.rampup_batch_size[2]): update_num_microbatches(consumed_samples, consistency_check=False) - consumed_samples += get_current_global_batch_size() + consumed_samples += args.real_global_batch_size iterations += 1 # Reset update_num_microbatches(0, consistency_check=False) # Constant phase # Note that we throw away any partial last batch. 
iterations += (args.train_samples - consumed_samples) // \ - args.global_batch_size + args.real_global_batch_size args.train_iters = iterations print_rank_0('setting training iterations to {}'.format(args.train_iters)) @@ -438,6 +440,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap return model + def get_optimizer_param_scheduler(optimizer): """Build the learning rate scheduler.""" args = get_args() @@ -446,12 +449,12 @@ def get_optimizer_param_scheduler(optimizer): if args.train_iters: if args.lr_decay_iters is None: args.lr_decay_iters = args.train_iters - lr_decay_steps = args.lr_decay_iters * args.global_batch_size - wd_incr_steps = args.train_iters * args.global_batch_size + lr_decay_steps = args.lr_decay_iters * args.real_global_batch_size + wd_incr_steps = args.train_iters * args.real_global_batch_size if args.lr_warmup_fraction is not None: lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps else: - lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size + lr_warmup_steps = args.lr_warmup_iters * args.real_global_batch_size # Sample-based training. elif args.train_samples: # We need to set training iters for later use. Technically @@ -486,6 +489,14 @@ def get_optimizer_param_scheduler(optimizer): return opt_param_scheduler + + + + + + + + def load_model_weights_only(model_provider_func): """Setup model and optimizer.""" args = get_args() @@ -588,7 +599,7 @@ def setup_model_and_optimizer(model_provider_func, train_samples = args.train_samples update_train_iters(args) else: - train_samples = args.train_iters * args.global_batch_size + train_samples = args.train_iters * args.real_global_batch_size # eval_iters and test_iters here are not actually used, only for # satisfying the input of build_train_valid_test_datasets_provider. # We only need to build the training data here. 
And we follow @@ -598,8 +609,8 @@ def setup_model_and_optimizer(model_provider_func, args.eval_iters test_iters = args.eval_iters train_val_test_num_samples = [train_samples, - eval_iters * args.global_batch_size, - test_iters * args.global_batch_size] + eval_iters * args.real_global_batch_size, + test_iters * args.real_global_batch_size] # Build the datasets. train_ds, _, _ = build_train_valid_test_datasets_provider( train_val_test_num_samples) @@ -740,9 +751,7 @@ def train_step(forward_step_func, data_iterator, # Update parameters. timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) if args.deepspeed: - increment = get_num_microbatches() * \ - args.micro_batch_size * \ - args.data_parallel_size + increment = args.real_global_batch_size model[0].step(lr_kwargs={'increment': increment}) update_successful = model[0].was_step_applied() else: @@ -772,9 +781,7 @@ def train_step(forward_step_func, data_iterator, return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad else: if update_successful: - increment = get_num_microbatches() * \ - args.micro_batch_size * \ - args.data_parallel_size + increment = args.real_global_batch_size opt_param_scheduler.step(increment=increment) skipped_iter = 0 else: @@ -1274,9 +1281,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, config) iteration += 1 args.iteration = iteration - new_samples = mpu.get_data_parallel_world_size() * \ - args.micro_batch_size * \ - get_num_microbatches() + new_samples = args.real_global_batch_size args.consumed_train_samples += new_samples # This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging. 
args.actual_seq_length = args.seq_length @@ -1584,11 +1589,11 @@ def build_train_valid_test_data_loaders( if args.iteration > 0 and args.consumed_train_samples == 0: assert args.train_samples is None, \ 'only backward compatiblity support for iteration-based training' - args.consumed_train_samples = args.iteration * args.global_batch_size + args.consumed_train_samples = args.iteration * args.real_global_batch_size if args.iteration > 0 and args.consumed_valid_samples == 0: if args.train_samples is None: args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ - args.eval_iters * args.global_batch_size + args.eval_iters * args.real_global_batch_size # Data loader only on rank 0 of each model parallel group. ds_sequence_parallel = mpu.get_sequence_parallel_world_size() > 1 or args.force_ds_sequence_parallel @@ -1664,4 +1669,4 @@ def build_train_valid_test_data_iterators( else: test_data_iterator = None - return train_data_iterator, valid_data_iterator, test_data_iterator + return train_data_iterator, valid_data_iterator, test_data_iterator \ No newline at end of file diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 88587d7e427..e55e736ac4a 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -215,6 +215,14 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): loss_mask = loss_mask.view(-1).float() loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + # Rescale by (local token count * DP size) / global token count so the later uniform data-parallel average of per-rank losses equals the token-weighted global mean. + m_local = loss_mask.sum() + m_total = m_local.detach().clone() + torch.distributed.all_reduce(m_total, group=mpu.get_data_parallel_group()) + dp = torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) + + + loss = loss * (m_local * dp) / (m_total) # Reduce loss for logging. averaged_loss = average_losses_across_data_parallel_group([loss]) if args.mos or args.kd: