@@ -282,13 +282,16 @@ def training_step(self, batch, batch_idx):
        tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

        # handle asynchronous grad reduction
        custom_sync_context_handler = None
        custom_grad_sync_func = None
        if self.with_distributed_adam:
            if self.megatron_amp_o2:
                # copy grads to main grad
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
            else:
                # keep grad tensors around
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
            custom_grad_sync_func = self.reduce_overlap_gradients
        else:
            if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False):
                custom_sync_context_handler = self._optimizer.no_sync
@@ -309,6 +312,7 @@ def training_step(self, batch, batch_idx):
                dtype=self.autocast_dtype,
                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                custom_sync_context_handler=custom_sync_context_handler,
                custom_grad_sync_func=custom_grad_sync_func,
                sequence_parallel_enabled=self.cfg.get('sequence_parallel', False),
                sync_batch_comm=self.cfg.get('sync_batch_comm', True),
                num_micro_batches_with_partial_activation_checkpoints=self.cfg.get(
@@ -330,11 +334,8 @@ def training_step(self, batch, batch_idx):
            self.allreduce_sequence_parallel_gradients()

        if self.with_distributed_adam:
            # launch grad reductions
            # Note: grads in first pipeline stage have already been
            # reduced
            if not parallel_state.is_pipeline_first_stage():
                self.reduce_overlap_gradients()
            # gradients are reduced internally in distributed optimizer
            pass
        elif self.megatron_amp_o2:
            # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
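The change above moves responsibility for launching the overlapped gradient reductions out of training_step and into the pipeline schedule, which now receives the new custom_grad_sync_func hook alongside custom_sync_context_handler. The sketch below is a toy, self-contained mock of that contract, not the real Apex/Megatron-LM schedule or NeMo code: MockDistributedOptimizer and mock_forward_backward are made-up names, while no_sync(greedy_grad_copy=...), reduce_overlap_gradients, and the two hook names come from the diff.

import contextlib


class MockDistributedOptimizer:
    """Hypothetical stand-in for the distributed Adam optimizer."""

    @contextlib.contextmanager
    def no_sync(self, greedy_grad_copy=False):
        # While this context is active, gradient reductions are deferred.
        print(f"deferring grad reductions (greedy_grad_copy={greedy_grad_copy})")
        yield

    def reduce_overlap_gradients(self):
        # Launch the gradient reductions that were deferred above.
        print("launching overlapped grad reductions")


def mock_forward_backward(num_microbatches, custom_sync_context_handler=None, custom_grad_sync_func=None):
    """Toy schedule: run all but the last microbatch under the no-sync context,
    then let the schedule (not the model) kick off grad synchronization."""
    make_context = custom_sync_context_handler or contextlib.nullcontext
    for i in range(num_microbatches):
        is_last = i == num_microbatches - 1
        with contextlib.nullcontext() if is_last else make_context():
            pass  # forward + backward for microbatch i would run here
    if custom_grad_sync_func is not None:
        custom_grad_sync_func()


optimizer = MockDistributedOptimizer()
mock_forward_backward(
    num_microbatches=4,
    custom_sync_context_handler=lambda: optimizer.no_sync(greedy_grad_copy=False),
    custom_grad_sync_func=optimizer.reduce_overlap_gradients,
)

Because the schedule itself now triggers the reductions, the post-pipeline call to reduce_overlap_gradients (previously skipped only on the first pipeline stage) becomes unnecessary, which is why both files replace that branch with a plain pass.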
@@ -308,13 +308,16 @@ def training_step(self, batch, batch_idx):
        tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size]

        # handle asynchronous grad reduction
        custom_sync_context_handler = None
        custom_grad_sync_func = None
        if self.with_distributed_adam:
            if self.megatron_amp_o2:
                # copy grads to main grad
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
            else:
                # keep grad tensors around
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
            custom_grad_sync_func = self.reduce_overlap_gradients
        else:
            if (
                self.megatron_amp_o2
@@ -339,6 +342,7 @@ def training_step(self, batch, batch_idx):
                dtype=self.autocast_dtype,
                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                custom_sync_context_handler=custom_sync_context_handler,
                custom_grad_sync_func=custom_grad_sync_func,
            )
        else:
            losses_reduced_per_micro_batch = forward_backward_no_pipelining(
@@ -363,11 +367,8 @@ def training_step(self, batch, batch_idx):
            loss_mean = torch.tensor(0.0).cuda()

        if self.with_distributed_adam:
            # launch grad reductions
            # Note: grads in first pipeline stage have already been
            # reduced
            if not parallel_state.is_pipeline_first_stage():
                self.reduce_overlap_gradients()
            # gradients are reduced internally in distributed optimizer
            pass
        elif self.megatron_amp_o2:
            # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
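For reference, the two no_sync branches in both files differ only in greedy_grad_copy, per the comments "copy grads to main grad" and "keep grad tensors around". The following is a minimal sketch of that distinction under stated assumptions; ToyMainGradBuffer is a made-up illustration, not NeMo or Apex internals.

import torch


class ToyMainGradBuffer:
    """Hypothetical fp32 main-grad storage, one buffer per parameter."""

    def __init__(self, params):
        self.main_grads = {p: torch.zeros_like(p, dtype=torch.float32) for p in params}

    def copy_grad(self, p):
        # Accumulate into the fp32 buffer and free the low-precision grad right away.
        self.main_grads[p] += p.grad.float()
        p.grad = None


params = [torch.nn.Parameter(torch.randn(4).half()) for _ in range(2)]
buffer = ToyMainGradBuffer(params)

greedy_grad_copy = True  # corresponds to the megatron_amp_o2 branch in the diff
for p in params:
    p.grad = torch.ones_like(p)  # pretend backward just produced this grad
    if greedy_grad_copy:
        buffer.copy_grad(p)  # copy to the main grad immediately, freeing the fp16 grad
# With greedy_grad_copy=False the fp16 grads would instead stay on p.grad until
# reduce_overlap_gradients() launches the deferred reduction.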