diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index b176592c4b63..3828be54a09c 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -282,6 +282,8 @@ def training_step(self, batch, batch_idx):
         tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
 
         # handle asynchronous grad reduction
+        custom_sync_context_handler = None
+        custom_grad_sync_func = None
         if self.with_distributed_adam:
             if self.megatron_amp_o2:
                 # copy grads to main grad
@@ -289,6 +291,7 @@ def training_step(self, batch, batch_idx):
             else:
                 # keep grad tensors around
                 custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
+            custom_grad_sync_func = self.reduce_overlap_gradients
         else:
             if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False):
                 custom_sync_context_handler = self._optimizer.no_sync
@@ -309,6 +312,7 @@ def training_step(self, batch, batch_idx):
                 dtype=self.autocast_dtype,
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 custom_sync_context_handler=custom_sync_context_handler,
+                custom_grad_sync_func=custom_grad_sync_func,
                 sequence_parallel_enabled=self.cfg.get('sequence_parallel', False),
                 sync_batch_comm=self.cfg.get('sync_batch_comm', True),
                 num_micro_batches_with_partial_activation_checkpoints=self.cfg.get(
@@ -330,11 +334,8 @@ def training_step(self, batch, batch_idx):
             self.allreduce_sequence_parallel_gradients()
 
         if self.with_distributed_adam:
-            # launch grad reductions
-            # Note: grads in first pipeline stage have already been
-            # reduced
-            if not parallel_state.is_pipeline_first_stage():
-                self.reduce_overlap_gradients()
+            # gradients are reduced internally in distributed optimizer
+            pass
         elif self.megatron_amp_o2:
             # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index b6d70dfb649e..a54a6362558e 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -308,6 +308,8 @@ def training_step(self, batch, batch_idx):
         tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size]
 
         # handle asynchronous grad reduction
+        custom_sync_context_handler = None
+        custom_grad_sync_func = None
         if self.with_distributed_adam:
             if self.megatron_amp_o2:
                 # copy grads to main grad
@@ -315,6 +317,7 @@ def training_step(self, batch, batch_idx):
             else:
                 # keep grad tensors around
                 custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
+            custom_grad_sync_func = self.reduce_overlap_gradients
         else:
             if (
                 self.megatron_amp_o2
@@ -339,6 +342,7 @@ def training_step(self, batch, batch_idx):
                 dtype=self.autocast_dtype,
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 custom_sync_context_handler=custom_sync_context_handler,
+                custom_grad_sync_func=custom_grad_sync_func,
             )
         else:
             losses_reduced_per_micro_batch = forward_backward_no_pipelining(
@@ -363,11 +367,8 @@ def training_step(self, batch, batch_idx):
             loss_mean = torch.tensor(0.0).cuda()
 
         if self.with_distributed_adam:
-            # launch grad reductions
-            # Note: grads in first pipeline stage have already been
-            # reduced
-            if not parallel_state.is_pipeline_first_stage():
-                self.reduce_overlap_gradients()
+            # gradients are reduced internally in distributed optimizer
+            pass
         elif self.megatron_amp_o2:
             # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
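The pattern in both files is the same: `training_step` no longer calls `self.reduce_overlap_gradients()` itself behind an `is_pipeline_first_stage()` guard. Instead it hands the forward/backward schedule a no-sync context (`custom_sync_context_handler`, wrapping `self._optimizer.no_sync(...)`) and a grad-sync callback (`custom_grad_sync_func = self.reduce_overlap_gradients`), so the schedule decides when to launch the deferred reductions. The sketch below is a hypothetical, simplified microbatch loop showing how such hooks are typically consumed; the function `run_microbatches`, its signature, and the zero-argument callback convention are assumptions for illustration, not the apex or NeMo API.

```python
# Hypothetical sketch, not the apex schedule or NeMo API: it only illustrates
# how a microbatch loop can consume the two hooks wired up in this patch.
import contextlib
from typing import Callable, Optional, Sequence

import torch


def run_microbatches(
    model: torch.nn.Module,
    microbatches: Sequence[torch.Tensor],
    loss_fn: Callable[[torch.Tensor], torch.Tensor],
    custom_sync_context_handler: Optional[Callable] = None,
    custom_grad_sync_func: Optional[Callable[[], None]] = None,
) -> None:
    # Mirror the "custom_sync_context_handler = None" default from the diff:
    # with no handler, gradients sync eagerly and behavior is unchanged.
    sync_context = (
        custom_sync_context_handler()
        if custom_sync_context_handler is not None
        else contextlib.nullcontext()
    )

    with sync_context:
        # Backward passes for all but the last microbatch run with gradient
        # synchronization deferred (optimizer.no_sync in the real code).
        for microbatch in microbatches[:-1]:
            loss_fn(model(microbatch)).backward()

    # The last microbatch runs outside the no-sync context, so a DDP-style
    # wrapper could overlap its reduction with this backward pass.
    loss_fn(model(microbatches[-1])).backward()

    # The schedule, not training_step, now decides when to flush deferred
    # reductions; this is why the explicit reduce_overlap_gradients() call
    # guarded by parallel_state.is_pipeline_first_stage() was removed.
    if custom_grad_sync_func is not None:
        custom_grad_sync_func()


if __name__ == "__main__":
    model = torch.nn.Linear(4, 1)
    batches = [torch.randn(2, 4) for _ in range(3)]
    run_microbatches(
        model,
        batches,
        loss_fn=lambda out: out.square().mean(),
        custom_sync_context_handler=contextlib.nullcontext,
        custom_grad_sync_func=lambda: None,  # stand-in for reduce_overlap_gradients
    )
```

With the hooks passed through to the schedule, launching grad reductions per pipeline stage becomes the schedule's responsibility, which appears to be why the distributed-Adam branch in `training_step` can now be reduced to a comment and `pass`.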