Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions megatron/training.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@


# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

Expand Down Expand Up @@ -251,14 +253,14 @@ def pretrain(train_valid_test_dataset_provider,

config = core_transformer_config_from_args(args)
if args.do_valid:
prefix = f'iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from validation set'
prefix = f'iteration {iteration} on {args.eval_iters * args.real_global_batch_size}-sample draw from validation set'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func, config,
verbose=True, write_to_tensorboard=not args.skip_train)

if args.do_test:
prefix = f'iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from test set'
prefix = f'iteration {iteration} on {args.eval_iters * args.real_global_batch_size}-sample draw from test set'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
iteration, process_non_loss_data_func, config,
Expand All @@ -274,7 +276,7 @@ def update_train_iters(args):

# Constant batch size with sample-based training.
if args.rampup_batch_size is None:
args.train_iters = args.train_samples // args.global_batch_size
args.train_iters = args.train_samples // args.real_global_batch_size

else:
# Sample based training with rampup batch size.
Expand All @@ -283,14 +285,14 @@ def update_train_iters(args):
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples, consistency_check=False)
consumed_samples += get_current_global_batch_size()
consumed_samples += args.real_global_batch_size
iterations += 1
# Reset
update_num_microbatches(0, consistency_check=False)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
args.global_batch_size
args.real_global_batch_size
args.train_iters = iterations

print_rank_0('setting training iterations to {}'.format(args.train_iters))
Expand Down Expand Up @@ -438,6 +440,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
return model



def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
Expand All @@ -446,12 +449,12 @@ def get_optimizer_param_scheduler(optimizer):
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
lr_decay_steps = args.lr_decay_iters * args.global_batch_size
wd_incr_steps = args.train_iters * args.global_batch_size
lr_decay_steps = args.lr_decay_iters * args.real_global_batch_size
wd_incr_steps = args.train_iters * args.real_global_batch_size
if args.lr_warmup_fraction is not None:
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
lr_warmup_steps = args.lr_warmup_iters * args.real_global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
Expand Down Expand Up @@ -486,6 +489,14 @@ def get_optimizer_param_scheduler(optimizer):

return opt_param_scheduler









def load_model_weights_only(model_provider_func):
"""Setup model and optimizer."""
args = get_args()
Expand Down Expand Up @@ -588,7 +599,7 @@ def setup_model_and_optimizer(model_provider_func,
train_samples = args.train_samples
update_train_iters(args)
else:
train_samples = args.train_iters * args.global_batch_size
train_samples = args.train_iters * args.real_global_batch_size
# eval_iters and test_iters here are not actually used, only for
# satisfying the input of build_train_valid_test_datasets_provider.
# We only need to build the training data here. And we follow
Expand All @@ -598,8 +609,8 @@ def setup_model_and_optimizer(model_provider_func,
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
eval_iters * args.real_global_batch_size,
test_iters * args.real_global_batch_size]
# Build the datasets.
train_ds, _, _ = build_train_valid_test_datasets_provider(
train_val_test_num_samples)
Expand Down Expand Up @@ -740,9 +751,7 @@ def train_step(forward_step_func, data_iterator,
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
if args.deepspeed:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
increment = args.real_global_batch_size
model[0].step(lr_kwargs={'increment': increment})
update_successful = model[0].was_step_applied()
else:
Expand Down Expand Up @@ -772,9 +781,7 @@ def train_step(forward_step_func, data_iterator,
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
else:
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
increment = args.real_global_batch_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
Expand Down Expand Up @@ -1274,9 +1281,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
config)
iteration += 1
args.iteration = iteration
new_samples = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
new_samples = args.real_global_batch_size
args.consumed_train_samples += new_samples
# This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging.
args.actual_seq_length = args.seq_length
Expand Down Expand Up @@ -1584,11 +1589,11 @@ def build_train_valid_test_data_loaders(
if args.iteration > 0 and args.consumed_train_samples == 0:
assert args.train_samples is None, \
'only backward compatiblity support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
args.consumed_train_samples = args.iteration * args.real_global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
if args.train_samples is None:
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
args.eval_iters * args.real_global_batch_size

# Data loader only on rank 0 of each model parallel group.
ds_sequence_parallel = mpu.get_sequence_parallel_world_size() > 1 or args.force_ds_sequence_parallel
Expand Down Expand Up @@ -1664,4 +1669,4 @@ def build_train_valid_test_data_iterators(
else:
test_data_iterator = None

return train_data_iterator, valid_data_iterator, test_data_iterator
return train_data_iterator, valid_data_iterator, test_data_iterator
8 changes: 8 additions & 0 deletions pretrain_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,14 @@ def loss_func(loss_mask, moe_loss, mos_loss, output_tensor):
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Compute Σm: the local count of unmasked tokens, all-reduced across the data-parallel group.
m_local = loss_mask.sum()
m_total = m_local.detach().clone()
torch.distributed.all_reduce(m_total, group=mpu.get_data_parallel_group())
dp = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())


loss = loss * (m_local * dp) / (m_total)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
if args.mos or args.kd:
Expand Down