diff --git a/Dockerfile b/Dockerfile
index 42686059bbca..cd78ef4348e1 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -32,12 +32,11 @@ RUN apt-get update && \
     python-dev ffmpeg && \
     rm -rf /var/lib/apt/lists/*
 
-# FIXME a workaround to update apex. Remove when base image is updated
 WORKDIR /tmp/
 RUN git clone https://github.com/NVIDIA/apex.git && \
     cd apex && \
-    git checkout 3c19f1061879394f28272a99a7ea26d58f72dace && \
-    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./
+    git checkout 2b0e8371113fe70758f1964c40bf7dbe304fd9e6 && \
+    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
 
 # uninstall stuff from base container
 RUN pip uninstall -y sacrebleu torchtext
diff --git a/examples/nlp/language_modeling/megatron_t5_pretraining.py b/examples/nlp/language_modeling/megatron_t5_pretraining.py
index 462cc62d28eb..4f044cb3c34d 100644
--- a/examples/nlp/language_modeling/megatron_t5_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -38,6 +38,7 @@ def main(cfg) -> None:
     logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
 
     megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
+    with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
     plugins = []
     strategy = NLPDDPStrategy(
         no_ddp_communication_hook=True,  # we don't use DDP for async grad allreduce
@@ -52,7 +53,7 @@ def main(cfg) -> None:
             growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
             hysteresis=cfg.model.get('hysteresis', 2),
         )
-        if megatron_amp_o2:
+        if megatron_amp_o2 and not with_distributed_adam:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
         else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
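Aside: the pretraining script above selects the distributed-optimizer path purely from the optimizer name in the model config. A minimal sketch of a config that would take this path is shown below; only `megatron_amp_O2` and `optim.name = 'distributed_fused_adam'` come from the patch, the remaining hyperparameters are illustrative placeholders.

from omegaconf import OmegaConf

# Hypothetical config fragment; lr and weight_decay are placeholder values.
cfg = OmegaConf.create(
    {
        'model': {
            'megatron_amp_O2': True,
            'optim': {'name': 'distributed_fused_adam', 'lr': 1e-4, 'weight_decay': 0.01},
        }
    }
)

# Same check the pretraining script performs
with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
print(with_distributed_adam)  # True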
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 8091567ead62..e3bb9ad9e675 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -14,7 +14,7 @@
 
 import os
 import re
-from typing import Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 from omegaconf import open_dict
@@ -53,9 +53,13 @@ class MegatronBaseModel(NLPModel):
     1. Initialize the model parallel for nemo given the model parallel parameters.
     2. Turn on all the nvidia optimizations.
     3. If `cfg.tokenizer` is available, it loads the tokenizer and pad the vocab to the correct size for tensor model parallelism.
-    4. It help to run `configure_gradient_clipping`, if `grad_clip_pl_default` is set True, it uses the pytorch lightning default
-       gradient clipping. Or if `megatron_amp_o2` is set True, it uses the parameters from optimizer to clip the gradients.
-       Otherwise, it uses the parameters calculated in the `setup_optimizer_param_groups` method.
+    4. If using distributed optimizer, configure to be compatible with
+       O2-level optimizations and/or model parallelism.
+    5. Perform gradient clipping: `grad_clip_pl_default` triggers the
+       PyTorch Lightning default implementation, `with_distributed_adam`
+       triggers the distributed optimizer's implementation,
+       `megatron_amp_o2` triggers gradient clipping on the main grads,
+       and otherwise gradient clipping is performed on the model grads.
     """
 
     def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
@@ -73,6 +77,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
 
         self._validate_config()
 
+        self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam'
+
         # used in NVIDIA NGC PyTorch containers
         self._enable_nvidia_optimizations()
 
@@ -220,7 +226,7 @@ def configure_gradient_clipping(self, *args, **kwargs):
             # use the default behavior
             return super().configure_gradient_clipping(*args, **kwargs)
 
-        if hasattr(self, 'with_distributed_adam') and self.with_distributed_adam:
+        if self.with_distributed_adam:
             grad_norm = clip_grad_norm_distributed_optimizer(self._optimizer, clip_val)
         else:
             if self.megatron_amp_o2:
@@ -256,6 +262,20 @@ def allreduce_gradients(self):
             for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                 buf.copy_(synced)
 
+    def reduce_overlap_gradients(self):
+        """Reduce grads if overlapped grad sync is enabled
+
+        Used for pipeline parallelism with the distributed Adam
+        optimizer. In the first pipeline stage, the grad sync is
+        overlapped with the final backward pass. In other pipeline
+        stages, the grad sync is deferred until the bubble overhead.
+
+        """
+        if self.with_distributed_adam:
+            self._optimizer.try_grad_sync(
+                p for p in self._optimizer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
+            )
+
     def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None:
         super().on_train_batch_end(outputs, batch, batch_idx)
 
@@ -294,15 +314,37 @@ def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[in
                     # accumulated gradient updates.
                     grad_scaler.optimizer_update_skipped = None
 
+    def setup_optimization(
+        self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
+        if self.with_distributed_adam:
+
+            # Allocate grads since we are storing between microbatches
+            optim_kwargs['contiguous_grad_buffer'] = True
+
+            if self.megatron_amp_o2:
+                # Match param allgather with model dtype
+                if hasattr(self, 'autocast_dtype'):
+                    optim_kwargs['param_sync_dtype'] = self.autocast_dtype
+                    if self.autocast_dtype == torch.float:
+                        optim_kwargs['store_params'] = False
+                    elif self.autocast_dtype == torch.float16:
+                        optim_kwargs['store_params'] = True
+                    elif self.autocast_dtype == torch.bfloat16:
+                        optim_kwargs['store_params'] = False
+                        optim_kwargs['store_param_remainders'] = True
+            else:
+                # Assume FP32 params, so no need to store main params
+                optim_kwargs['store_params'] = False
+
+        return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
+
     def configure_optimizers(self):
         self.setup_optimization()
 
         # Wrap the baseline optimizer with the optimizer class with master parameters
-        if (
-            self.megatron_amp_o2
-            and not (hasattr(self, 'with_distributed_adam') and self.with_distributed_adam)
-            and self._optimizer is not None
-        ):
+        if self.megatron_amp_o2 and not self.with_distributed_adam and self._optimizer is not None:
             if self.cfg.precision == 'bf16':
                 fp32_grad_accum = True
                 contiguous_grad_bucket = True
@@ -347,6 +389,16 @@ def configure_optimizers(self):
                 optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
             )
 
+        # Configure distributed optimizer
+        if self.with_distributed_adam:
+
+            # Initialize params so that main grads are available
+            # Note: Consolidate grads without overlap
+            self._optimizer.init_params(
+                p for p in self.parameters() if getattr(p, '_disable_overlap_grad_sync', False)
+            )
+            self._optimizer.init_params(self.parameters())
+
         if self._scheduler is None:
             return self._optimizer
         else:
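Aside: the dtype-dependent options that setup_optimization passes to the distributed optimizer can be summarized as a small decision function. This is an illustrative sketch, not part of the patch; the helper name is hypothetical and only the keyword names are taken from the code above.

import torch

def distributed_adam_optim_kwargs(autocast_dtype, megatron_amp_o2):
    kwargs = {'contiguous_grad_buffer': True}  # grads are kept between microbatches
    if megatron_amp_o2:
        kwargs['param_sync_dtype'] = autocast_dtype  # allgather params in the model dtype
        if autocast_dtype == torch.float16:
            kwargs['store_params'] = True  # keep FP32 main params in the optimizer
        elif autocast_dtype == torch.bfloat16:
            kwargs['store_params'] = False
            kwargs['store_param_remainders'] = True  # keep only the bits FP32 adds to BF16
        else:
            kwargs['store_params'] = False  # params are already FP32
    else:
        kwargs['store_params'] = False  # FP32 params, no separate main copy needed
    return kwargs

print(distributed_adam_optim_kwargs(torch.bfloat16, megatron_amp_o2=True))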
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 94e1960101ba..bb8bd73f996e 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Union
+import itertools
+from typing import Any, List, Optional, Union
 
 import numpy as np
 import torch
@@ -82,16 +83,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self._validate_trainer()
 
         self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)
-        self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam'
 
         if not self.megatron_amp_o2 and self.cfg.get('virtual_pipeline_model_parallel_size', None):
             raise ValueError('Virtual pipeline model parallel is only supported when using megatron_amp_O2')
 
-        if self.with_distributed_adam and not self.megatron_amp_o2:
-            raise ValueError(
-                "Distributed optimizers require O2. Please set megatron_amp_O2 to True in the model config."
-            )
-
         # build_model returns a list of modules which are used for interleaved pipeline parallelism
         self.model = build_model(
             model_provider_func=self.model_provider_func,
@@ -186,15 +181,40 @@ def setup_optimizer_param_groups(self):
         else:
             self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model)
 
-    def setup_optimization(
-        self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
+    def configure_optimizers(self):
+
+        if self.with_distributed_adam:
-        optim_kwargs['process_group'] = parallel_state.get_data_parallel_group()
-        optim_kwargs['param_sync_dtype'] = self.autocast_dtype
-        optim_kwargs['contiguous_grad_buffer'] = True
-        return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
+
+            # Disable overlapped grad sync for embedding grad when
+            # pipeline parallelism is enabled
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+                if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+                    if isinstance(self.model, list):
+                        module = self.model[0]  # only the first virtual rank has the embeddings
+                    else:
+                        module = self.model
+                    if module.share_token_embeddings:
+                        param = module.word_embeddings_weight()
+                        param._disable_greedy_grad_copy = not self.megatron_amp_o2
+                        param._disable_overlap_grad_sync = True
+                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+                    if isinstance(self.model, list):
+                        module = self.model[-1]  # only the last virtual rank has the embeddings
+                    else:
+                        module = self.model
+                    if module.share_token_embeddings:
+                        param = module.word_embeddings_weight()
+                        param._disable_greedy_grad_copy = not self.megatron_amp_o2
+                        param._disable_overlap_grad_sync = True
+
+            # Disable overlapped grad sync for layer norm grads when
+            # sequence parallelism is enabled
+            for param in self.parameters():
+                if getattr(param, 'sequence_parallel_enabled', False):
+                    param._disable_greedy_grad_copy = not self.megatron_amp_o2
+                    param._disable_overlap_grad_sync = True
+
+        return super().configure_optimizers()
 
     def forward(self, tokens, text_position_ids, attention_mask, labels):
         output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels)
@@ -236,16 +256,20 @@ def training_step(self, batch, batch_idx):
 
         tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
 
-        # determine if we can use async grad all reduce
-        custom_sync_context_handler = None
-        if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False):
-            if self.with_distributed_adam:
+        # handle asynchronous grad reduction
+        if self.with_distributed_adam:
+            if self.megatron_amp_o2:
+                # copy grads to main grad
                 custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
             else:
-                custom_sync_context_handler = self._optimizer.no_sync
+                # keep grad tensors around
+                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
         else:
-            # TODO: enable async grad all reduce for O1/autocast mixed precision training
-            custom_sync_context_handler = None
+            if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False):
+                custom_sync_context_handler = self._optimizer.no_sync
+            else:
+                # TODO: enable async grad all reduce for O1/autocast mixed precision training
+                custom_sync_context_handler = None
 
         # run forward and backwards passes for an entire global batch
         # we do this inside training_step to support pipeline parallelism
@@ -277,8 +301,11 @@ def training_step(self, batch, batch_idx):
             self.allreduce_sequence_parallel_gradients()
 
         if self.with_distributed_adam:
-            # gradients are reduced internally in distributed optimizer
-            pass
+            # launch grad reductions
+            # Note: grads in first pipeline stage have already been
+            # reduced
+            if not parallel_state.is_pipeline_first_stage():
+                self.reduce_overlap_gradients()
         elif self.megatron_amp_o2:
             # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
@@ -753,23 +780,6 @@ def setup_test_data(self, cfg):
             )
         self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples)
 
-    def configure_optimizers(self):
-        retval = super().configure_optimizers()
-
-        if self.with_distributed_adam:
-
-            # Initialize params in reverse order
-            # Note: Estimate order in which grads are generated in
-            # backward pass
-            self._optimizer.init_params(reversed(list(self.parameters())))
-
-            # Overlapped communication interferes with grad reductions
-            # for pipeline parallelism and sequence parallelism
-            if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
-                self._optimizer.overlap_grad_sync = False
-
-        return retval
-
     def generate(
         self,
         inputs: Union[List[str], torch.Tensor, List[dict]],
@@ -878,3 +888,9 @@ def on_load_checkpoint(self, checkpoint) -> None:
                 parallel_state.set_virtual_pipeline_model_parallel_rank(i)
                 self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True)
             parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+
+    def parameters(self):
+        if isinstance(self.model, list):
+            return itertools.chain.from_iterable(module.parameters() for module in self.model)
+        else:
+            return self.model.parameters()
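Aside: the per-parameter flags set in configure_optimizers above are plain Python attributes on the parameter tensors; the optimizer later filters on them (see reduce_overlap_gradients). A toy sketch of that convention, assuming nothing beyond PyTorch; the Linear module is a stand-in for a NeMo model, only the attribute names come from the patch.

import torch

layer = torch.nn.Linear(16, 16)
layer.weight._disable_greedy_grad_copy = True   # e.g. a shared embedding weight
layer.weight._disable_overlap_grad_sync = True  # reduced explicitly after the pipeline

overlapped = [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
deferred = [p for p in layer.parameters() if getattr(p, '_disable_overlap_grad_sync', False)]
print(len(overlapped), len(deferred))  # 1 1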
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index 02ae2030aca5..b6d70dfb649e 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -101,8 +101,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         if self.megatron_amp_o2:
 
-            # Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
-            self.enc_dec_model.cuda(torch.cuda.current_device())
+            if not self.with_distributed_adam:
+                # Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
+                self.enc_dec_model.cuda(torch.cuda.current_device())
 
             # Model wrapper to convert both model and inputs to half precision
             self.enc_dec_model = Float16Module(module=self.enc_dec_model, precision=cfg.precision)
@@ -122,6 +123,58 @@ def setup_optimizer_param_groups(self):
         """ModelPT override. Optimizer will get self._optimizer_param_groups"""
         self._optimizer_param_groups = get_params_for_weight_decay_optimization([self.enc_dec_model])
 
+    def configure_optimizers(self):
+
+        if self.with_distributed_adam:
+
+            # Identify params that require grad reductions between
+            # pipeline stages
+            # See: allreduce_word_and_position_embeddings
+            model_parallel_params = []
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1 and (
+                parallel_state.is_rank_in_embedding_group()
+            ):
+                if self.cfg.get('share_token_embeddings', True) and self.cfg.get(
+                    'share_decoder_tokens_head_embeddings', True
+                ):
+                    model_parallel_params.append(self.enc_dec_model.word_embeddings_weight())
+            if (
+                parallel_state.is_rank_in_position_embedding_group()
+                and parallel_state.get_pipeline_model_parallel_world_size() > 1
+                and parallel_state.get_pipeline_model_parallel_split_rank() is not None
+                and self.cfg.encoder.get('position_embedding_type') == 'learned_absolute'
+                and self.cfg.decoder.get('position_embedding_type') == 'learned_absolute'
+            ):
+                if self.cfg.get('share_token_embeddings', True):
+                    model_parallel_params.append(self.enc_dec_model.position_embeddings_weight())
+            if (
+                parallel_state.get_pipeline_model_parallel_world_size() > 2
+                and parallel_state.get_pipeline_model_parallel_split_rank() is not None
+            ):
+                if (
+                    self.cfg.encoder.get('position_embedding_type') == 'relative'
+                    and parallel_state.is_rank_in_encoder_relative_position_embedding_group()
+                    and parallel_state.get_pipeline_model_parallel_split_rank() > 1
+                ):
+                    model_parallel_params.append(self.enc_dec_model.encoder_relative_position_embeddings_weight())
+                if (
+                    self.cfg.decoder.get('position_embedding_type') == 'relative'
+                    and parallel_state.is_rank_in_decoder_relative_position_embedding_group()
+                ):
+                    model_parallel_params.append(self.enc_dec_model.decoder_relative_position_embeddings_weight())
+                    if not self.cfg.decoder.get('relative_position_bias_self_attention_only', True):
+                        model_parallel_params.append(
+                            self.enc_dec_model.decoder_cross_attention_relative_position_embeddings_weight()
+                        )
+
+            # Disable async grad reductions for params that are
+            # synchronized for pipeline parallelism
+            for param in model_parallel_params:
+                param._disable_greedy_grad_copy = not self.megatron_amp_o2
+                param._disable_overlap_grad_sync = True
+
+        return super().configure_optimizers()
+
     def _handle_bias_activation_fusion_args(self, cfg):
         # For oldest models, we don't have the option to turn on/off bias activation fusion. It is always on.
         if not hasattr(cfg, 'bias_gelu_fusion') and not hasattr(cfg, 'bias_activation_fusion'):
@@ -254,6 +307,27 @@ def training_step(self, batch, batch_idx):
         decoder_seq_length = batch_for_pipeline[1].size(1)
         tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size]
 
+        # handle asynchronous grad reduction
+        if self.with_distributed_adam:
+            if self.megatron_amp_o2:
+                # copy grads to main grad
+                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
+            else:
+                # keep grad tensors around
+                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
+        else:
+            if (
+                self.megatron_amp_o2
+                and self.cfg.get('pipeline_model_parallel_size', 1) == 1
+                and not self.cfg.get('sequence_parallel', False)
+            ):
+                custom_sync_context_handler = self._optimizer.no_sync
+            else:
+                # TODO: enable async grad all reduce with O1/autocast
+                # mixed precision training, with pipeline parallelism,
+                # or with sequence parallelism
+                custom_sync_context_handler = None
+
         if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
             losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
                 forward_step_func=self.get_forward_output_and_loss_func(),
@@ -264,14 +338,9 @@ def training_step(self, batch, batch_idx):
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
+                custom_sync_context_handler=custom_sync_context_handler,
             )
         else:
-            # no pipeline parallelism so we reduce grads asynchronously
-            if self.megatron_amp_o2:
-                custom_sync_context_handler = self._optimizer.no_sync
-            else:
-                # TODO: enable async grad all reduce for O1/autocast mixed precision training
-                custom_sync_context_handler = None
             losses_reduced_per_micro_batch = forward_backward_no_pipelining(
                 forward_step_func=self.get_forward_output_and_loss_func(),
                 batch=batch_for_pipeline,
@@ -293,7 +362,13 @@ def training_step(self, batch, batch_idx):
         else:
             loss_mean = torch.tensor(0.0).cuda()
 
-        if self.megatron_amp_o2:
+        if self.with_distributed_adam:
+            # launch grad reductions
+            # Note: grads in first pipeline stage have already been
+            # reduced
+            if not parallel_state.is_pipeline_first_stage():
+                self.reduce_overlap_gradients()
+        elif self.megatron_amp_o2:
             # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                 # main grads are stored in the MainParamsOptimizer wrapper
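Aside: custom_sync_context_handler in the training steps above is a zero-argument callable that returns a context manager; the forward/backward helpers enter it around backprop so grad reductions can be deferred or overlapped. A toy sketch of that contract; the class below is a stand-in for the Apex distributed optimizer, and only the no_sync(greedy_grad_copy=...) signature mirrors the calls in the patch.

import contextlib

class _ToyDistributedOptimizer:
    @contextlib.contextmanager
    def no_sync(self, greedy_grad_copy=False):
        print(f'grad sync deferred (greedy_grad_copy={greedy_grad_copy})')
        yield
        print('grad sync can now be launched')

_optimizer = _ToyDistributedOptimizer()
custom_sync_context_handler = lambda: _optimizer.no_sync(greedy_grad_copy=True)
with custom_sync_context_handler():
    pass  # the backward pass for a microbatch would run here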
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py
index 0d5e49b5a76f..97efa254459a 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py
@@ -86,8 +86,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         if self.megatron_amp_o2:
 
-            # Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
-            self.model.cuda(torch.cuda.current_device())
+            if not self.with_distributed_adam:
+                # Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
+                self.model.cuda(torch.cuda.current_device())
 
             # Model wrapper to convert both model and inputs to half precision
             self.model = Float16Module(module=self.model, precision=self.cfg.precision)
@@ -253,13 +254,16 @@ def training_step(self, batch, batch_idx):
             loss_scale = self.trainer.precision_plugin.scaler._scale
             if loss_scale is not None:
                 self.log('loss_scale', loss_scale)
 
-        # while async grad allreduce is enabled, bprop will keep moving forward without waiting for
-        # the finish of async grad AR works. Hence, to guarantee the correctness of grads reduction,
-        # we cannot start weight update until all async grad AR works are done.
-        if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) == 1:
-            torch.cuda.synchronize()
-        if self.megatron_amp_o2:
+        if self.with_distributed_adam:
+            # gradients are reduced internally in distributed optimizer
+            pass
+        elif self.megatron_amp_o2:
+            # while async grad allreduce is enabled, bprop will keep moving forward without waiting for
+            # the finish of async grad AR works. Hence, to guarantee the correctness of grads reduction,
+            # we cannot start weight update until all async grad AR works are done.
+            if self.cfg.get('pipeline_model_parallel_size', 1) == 1:
+                torch.cuda.synchronize()
             # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
             if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                 # main grads are stored in the MainParamsOptimizer wrapper
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index a693fc6fa993..2e52be81ce34 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -92,8 +92,10 @@ def configure_ddp(self):
         Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
         """
-        if hasattr(self.model, 'megatron_amp_o2') and self.model.megatron_amp_o2:
-            # do not use DDP if using megatron amp O2
+        if (hasattr(self.model, 'megatron_amp_o2') and self.model.megatron_amp_o2) or (
+            hasattr(self.model, 'with_distributed_adam') and self.model.with_distributed_adam
+        ):
+            # do not use DDP if using megatron amp O2 or distributed optimizer
             self._model = LightningDistributedModule(self.model)
         else:
             app_state = AppState()
@@ -419,6 +421,12 @@ def __init__(
         self.hysteresis = hysteresis
         self._hysteresis_tracker = self.hysteresis
 
+    def _unscale_grads_(self, optimizer, *args):
+        if getattr(optimizer, "_custom_amp_unscale_grads", False):
+            return optimizer.unscale_grads(*args)
+        else:
+            return super()._unscale_grads_(optimizer, *args)
+
     def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
         retval = None
         found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])
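Aside: the _unscale_grads_ hook above lets the gradient scaler delegate unscaling to any optimizer that opts in through the _custom_amp_unscale_grads attribute. A toy sketch of that contract; the optimizer below is a stand-in rather than the Apex one, and only the attribute and method names follow the patch.

import torch

class _ToyUnscalingOptimizer(torch.optim.SGD):
    _custom_amp_unscale_grads = True

    def unscale_grads(self, *args):
        # a real implementation would unscale its local gradient shards here
        return {}

opt = _ToyUnscalingOptimizer([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
print(getattr(opt, '_custom_amp_unscale_grads', False))  # True, so the scaler delegates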
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
new file mode 100644
index 000000000000..cba3985c6fcb
--- /dev/null
+++ b/nemo/core/optim/distributed_adam.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
+from apex.transformer import parallel_state
+
+
+# Wrapper class that supports main_grad buffer
+# Note: main_grad buffer is used for O2-style optimizations
+class MegatronDistributedFusedAdam(DistributedFusedAdam):
+    def __init__(self, *args, **kwargs):
+        if 'process_group' not in kwargs and not parallel_state.is_unitialized():
+            kwargs['process_group'] = parallel_state.get_data_parallel_group()
+        super().__init__(*args, **kwargs)
+
+    def _make_post_backward_hook(self, param, param_group_id, param_id):
+        def hook(*unused):
+            with self._lock:
+                need_to_initialize = 'fragments' not in self.state[param]
+                if need_to_initialize:
+                    self._init_param_state(param, param_group_id, param_id)
+                if self.greedy_grad_copy and not getattr(param, '_disable_greedy_grad_copy', False):
+                    self._grad_copy(param)
+                    if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
+                        self._try_start_bucket_grad_sync(
+                            params=[param], ignore_last_bucket=need_to_initialize,
+                        )
+
+        return hook
+
+    def try_grad_sync(self, params):
+        params = list(params)
+        for p in params:
+            self._grad_copy(p)
+        self._try_start_bucket_grad_sync(params=params)
+
+    def zero_grad(self, *args, **kwargs):
+        super().zero_grad(*args, **kwargs)
+        if self.contiguous_grad_buffer:
+            for param in self.parameters():
+                param.main_grad = self.grad_buffer_view(param)
+                if param.dtype == param.main_grad.dtype and param.is_cuda:
+                    param.grad = param.main_grad
diff --git a/nemo/core/optim/optimizers.py b/nemo/core/optim/optimizers.py
index 97187358ffec..fd9f1b4672ea 100644
--- a/nemo/core/optim/optimizers.py
+++ b/nemo/core/optim/optimizers.py
@@ -57,22 +57,12 @@
 HAVE_APEX_DISTRIBUTED_ADAM = False
 if HAVE_APEX:
     try:
-
-        # Try importing Apex distributed Adam optimizer
-        from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
-        import fused_adam_cuda, distributed_adam_cuda  # Required kernels
+        # Try importing wrapper for Apex distributed Adam optimizer
+        from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
 
         HAVE_APEX_DISTRIBUTED_ADAM = True
 
-        # Wrapper class that supports main_grad buffer
-        # Note: main_grad buffer is used for O2-style optimizations
-        class MegatronDistributedFusedAdam(DistributedFusedAdam):
-            def _init_param_state(self, param, param_group_id, param_id):
-                super()._init_param_state(param, param_group_id, param_id)
-                param.main_grad = self.grad_buffer_view(param)
-
         AVAILABLE_OPTIMIZERS['distributed_fused_adam'] = MegatronDistributedFusedAdam
-
     except (ImportError, ModuleNotFoundError):
         logging.warning("Could not import distributed_fused_adam optimizer from Apex")
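Aside: the effect of the main_grad aliasing set up in MegatronDistributedFusedAdam.zero_grad can be illustrated with plain tensors. This is a minimal sketch, assuming only PyTorch: the flat tensor below stands in for the optimizer's bucketed grad buffer, and it relies on autograd accumulating in place into an existing .grad during a standard backward pass.

import torch

param = torch.nn.Parameter(torch.randn(4))
grad_buffer = torch.zeros(8)        # stand-in for the contiguous grad buffer
param.main_grad = grad_buffer[:4]   # view into the buffer
if param.dtype == param.main_grad.dtype:
    param.grad = param.main_grad    # backward now accumulates into the buffer view

(2.0 * param).sum().backward()
print(torch.allclose(grad_buffer[:4], torch.full((4,), 2.0)))  # True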