Merged
Changes from all commits
30 commits
62d059e
Avoid storing extra copy of params in dist Adam optimizer
timmoon10 Aug 23, 2022
a62a368
Add support for dist Adam in GPT-3 without O2-level AMP
timmoon10 Aug 23, 2022
59a7859
Add support for dist Adam in Megatron-LM models
timmoon10 Aug 25, 2022
fa3049a
Merge branch 'main' into dist-adam-nlp-models
timmoon10 Aug 25, 2022
36109c1
Debug dist Adam support without Megatron AMP O2
timmoon10 Aug 26, 2022
23bce24
Merge branch 'main' into dist-adam-nlp-models
timmoon10 Aug 26, 2022
b6b509e
Merge branch 'main' into dist-adam-nlp-models
timmoon10 Aug 29, 2022
1bede48
Merge branch 'main' into dist-adam-nlp-models
timmoon10 Sep 7, 2022
065a89b
Add support for overlapped grad sync with pipeline parallelism in GPT-3
timmoon10 Sep 7, 2022
1ec40b1
Debug dist Adam support for T5
timmoon10 Sep 8, 2022
8b46a9b
Merge branch 'main' into dist-adam-nlp-models
timmoon10 Sep 8, 2022
7943ebc
Merge branch 'dist-adam-nlp-models' into dist-adam-pipeline-parallel-…
timmoon10 Sep 8, 2022
e06d34a
Add support for overlapped grad sync with pipeline parallelism in T5
timmoon10 Sep 9, 2022
d528a89
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Sep 15, 2022
811b59c
Update Apex commits in Dockerfile and Jenkinsfile
timmoon10 Sep 20, 2022
ebd98c4
Merge commit 'e3ac280a861fdda5889f5ded88508ceb259f2278' into dist-ada…
timmoon10 Sep 27, 2022
b2a61ad
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Sep 27, 2022
7c3551e
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Oct 4, 2022
39b3a88
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Oct 11, 2022
8da1ac5
Support distributed Adam in Megatron grad scaler class.
timmoon10 Oct 14, 2022
62acf17
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Oct 14, 2022
c3692af
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Oct 14, 2022
5da9e42
Update dist Adam to accommodate changes in GPT model
timmoon10 Oct 19, 2022
4ef0255
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Oct 19, 2022
a088304
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2022
6706b09
Minor tweaks to dist Adam integration
timmoon10 Oct 20, 2022
6ed59e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2022
aed0e00
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
timmoon10 Oct 20, 2022
0bfb2db
Remove error when dist Adam and interleaved pipeline parallelism are …
timmoon10 Oct 20, 2022
190f992
Merge branch 'main' into dist-adam-pipeline-parallel-async-grad-reduc…
ericharper Oct 20, 2022
5 changes: 2 additions & 3 deletions Dockerfile
@@ -32,12 +32,11 @@ RUN apt-get update && \
python-dev ffmpeg && \
rm -rf /var/lib/apt/lists/*

# FIXME a workaround to update apex. Remove when base image is updated
WORKDIR /tmp/
RUN git clone https://github.com/NVIDIA/apex.git && \
cd apex && \
git checkout 3c19f1061879394f28272a99a7ea26d58f72dace && \
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./
git checkout 2b0e8371113fe70758f1964c40bf7dbe304fd9e6 && \
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./

# uninstall stuff from base container
RUN pip uninstall -y sacrebleu torchtext
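The new Apex commit is built with the --distributed_adam and --deprecated_fused_adam extensions, which is what makes the distributed optimizer used throughout this PR importable. A minimal sanity check, assuming the extension lands under apex.contrib.optimizers.distributed_fused_adam as in upstream Apex (sketch only, not part of the diff):

# Sketch: verify the Apex build picked up the distributed Adam extension.
def check_apex_distributed_adam() -> None:
    try:
        # Assumed module path; provided by Apex when built with --distributed_adam.
        from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam  # noqa: F401
    except ImportError as err:
        raise RuntimeError(
            'Apex was built without the distributed Adam extension; '
            'rebuild with --global-option="--distributed_adam"'
        ) from err

if __name__ == '__main__':
    check_apex_distributed_adam()
    print('distributed Adam extension available')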
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -38,6 +38,7 @@ def main(cfg) -> None:
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
plugins = []
strategy = NLPDDPStrategy(
no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce
@@ -52,7 +53,7 @@
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
if megatron_amp_o2:
if megatron_amp_o2 and not with_distributed_adam:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
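A distilled restatement of the plugin choice above: once the optimizer is distributed_fused_adam, the MegatronHalfPrecisionPlugin branch is skipped even with megatron_amp_O2 set, since master-weight handling moves into the optimizer itself. The helper below is illustrative only and returns plugin names rather than constructing them:

def choose_precision_plugin(megatron_amp_o2: bool, with_distributed_adam: bool) -> str:
    """Name the precision plugin the script would append (sketch, not NeMo API)."""
    if megatron_amp_o2 and not with_distributed_adam:
        return 'MegatronHalfPrecisionPlugin'   # O2 master weights handled by the plugin
    return 'PipelineMixedPrecisionPlugin'      # dist Adam keeps its own FP32 state

assert choose_precision_plugin(True, True) == 'PipelineMixedPrecisionPlugin'
assert choose_precision_plugin(True, False) == 'MegatronHalfPrecisionPlugin'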
nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -14,7 +14,7 @@

import os
import re
from typing import Optional
from typing import Any, Dict, Optional, Union

import torch
from omegaconf import open_dict
@@ -53,9 +53,13 @@ class MegatronBaseModel(NLPModel):
1. Initialize the model parallel for nemo given the model parallel parameters.
2. Turn on all the nvidia optimizations.
3. If `cfg.tokenizer` is available, it loads the tokenizer and pad the vocab to the correct size for tensor model parallelism.
4. It help to run `configure_gradient_clipping`, if `grad_clip_pl_default` is set True, it uses the pytorch lightning default
gradient clipping. Or if `megatron_amp_o2` is set True, it uses the parameters from optimizer to clip the gradients.
Otherwise, it uses the parameters calculated in the `setup_optimizer_param_groups` method.
4. If using distributed optimizer, configure to be compatible with
O2-level optimizations and/or model parallelism.
5. Perform gradient clipping: `grad_clip_pl_default` triggers the
PyTorch Lightning default implementation, `with_distributed_adam`
triggers the distributed optimizer's implementation,
`megatron_amp_o2` triggers gradient clipping on the main grads,
and otherwise gradient clipping is performed on the model grads.
"""

def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
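Item 5 of the docstring above describes a four-way dispatch, and configure_gradient_clipping further down checks the flags in this order: PyTorch Lightning default, then the distributed optimizer's clipping, then main-grad clipping under O2, then model-grad clipping. A sketch that only labels the path taken (clip_grad_norm_distributed_optimizer is the NeMo utility referenced later in this diff; the O2 and model-grad branches are collapsed here, so those labels are descriptive, not exact function names):

def clipping_path(grad_clip_pl_default: bool, with_distributed_adam: bool, megatron_amp_o2: bool) -> str:
    """Label which clipping implementation would run (sketch only)."""
    if grad_clip_pl_default:
        return 'pytorch_lightning_default'
    if with_distributed_adam:
        return 'clip_grad_norm_distributed_optimizer'   # clips with the dist optimizer's main grads
    if megatron_amp_o2:
        return 'clip main grads (O2 wrapper)'
    return 'clip model grads'

assert clipping_path(False, True, True) == 'clip_grad_norm_distributed_optimizer'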
@@ -73,6 +77,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):

self._validate_config()

self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam'

# used in NVIDIA NGC PyTorch containers
self._enable_nvidia_optimizations()
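The with_distributed_adam flag added above is derived purely from the optimizer name in the model config, so any model built on MegatronBaseModel picks up the distributed code paths just by switching optim.name. A minimal sketch with a hypothetical config fragment:

from omegaconf import OmegaConf

# Hypothetical, minimal config; only optim.name matters for the flag.
cfg = OmegaConf.create({'optim': {'name': 'distributed_fused_adam', 'lr': 1e-4}})
with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam'
assert with_distributed_adam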

@@ -220,7 +226,7 @@ def configure_gradient_clipping(self, *args, **kwargs):
# use the default behavior
return super().configure_gradient_clipping(*args, **kwargs)

if hasattr(self, 'with_distributed_adam') and self.with_distributed_adam:
if self.with_distributed_adam:
grad_norm = clip_grad_norm_distributed_optimizer(self._optimizer, clip_val)
else:
if self.megatron_amp_o2:
@@ -256,6 +262,20 @@ def allreduce_gradients(self):
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

def reduce_overlap_gradients(self):
"""Reduce grads if overlapped grad sync is enabled

Used for pipeline parallelism with the distributed Adam
optimizer. In the first pipeline stage, the grad sync is
overlapped with the final backward pass. In other pipeline
stages, the grad sync is deferred until the bubble overhead.

"""
if self.with_distributed_adam:
self._optimizer.try_grad_sync(
p for p in self._optimizer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
)
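A toy illustration of the filter above: only params without the _disable_overlap_grad_sync attribute are handed to try_grad_sync, and per the docstring this path matters for non-first pipeline stages, where grad sync is deferred into the pipeline bubble rather than overlapped with the last backward pass. The parameters below are stand-ins, not NeMo model weights:

import torch

params = [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)]
params[0]._disable_overlap_grad_sync = True          # e.g. a shared embedding weight

overlap_params = [p for p in params if not getattr(p, '_disable_overlap_grad_sync', False)]
assert len(overlap_params) == 2                      # only these reach try_grad_sync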

def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None:
super().on_train_batch_end(outputs, batch, batch_idx)

@@ -294,15 +314,37 @@ def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None:
# accumulated gradient updates.
grad_scaler.optimizer_update_skipped = None

def setup_optimization(
self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
):
optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
if self.with_distributed_adam:

# Allocate grads since we are storing between microbatches
optim_kwargs['contiguous_grad_buffer'] = True

if self.megatron_amp_o2:
# Match param allgather with model dtype
if hasattr(self, 'autocast_dtype'):
optim_kwargs['param_sync_dtype'] = self.autocast_dtype
if self.autocast_dtype == torch.float:
optim_kwargs['store_params'] = False
elif self.autocast_dtype == torch.float16:
optim_kwargs['store_params'] = True
elif self.autocast_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
optim_kwargs['store_param_remainders'] = True
else:
# Assume FP32 params, so no need to store main params
optim_kwargs['store_params'] = False

return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
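The kwarg selection above reduces to a function of the O2 flag and the autocast dtype (assuming autocast_dtype is set). A sketch; the reading that store_param_remainders keeps only the low-order FP32 bits that bf16 drops is an assumption about the Apex option, not something stated in the diff:

import torch

def dist_adam_kwargs(megatron_amp_o2: bool, autocast_dtype: torch.dtype) -> dict:
    """Distill the optimizer kwargs chosen in setup_optimization (sketch only)."""
    kwargs = {'contiguous_grad_buffer': True}       # grads persist across microbatches
    if not megatron_amp_o2:
        kwargs['store_params'] = False              # params stay FP32; no main copy needed
        return kwargs
    kwargs['param_sync_dtype'] = autocast_dtype     # allgather params in the model dtype
    if autocast_dtype == torch.float16:
        kwargs['store_params'] = True               # keep FP32 main params
    elif autocast_dtype == torch.bfloat16:
        kwargs['store_params'] = False
        kwargs['store_param_remainders'] = True     # assumed: keep the bits bf16 drops
    else:                                           # torch.float
        kwargs['store_params'] = False
    return kwargs

assert dist_adam_kwargs(True, torch.bfloat16)['store_param_remainders'] is True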

def configure_optimizers(self):
self.setup_optimization()

# Wrap the baseline optimizer with the optimizer class with master parameters
if (
self.megatron_amp_o2
and not (hasattr(self, 'with_distributed_adam') and self.with_distributed_adam)
and self._optimizer is not None
):
if self.megatron_amp_o2 and not self.with_distributed_adam and self._optimizer is not None:
if self.cfg.precision == 'bf16':
fp32_grad_accum = True
contiguous_grad_bucket = True
@@ -347,6 +389,16 @@ def configure_optimizers(self):
optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
)

# Configure distributed optimizer
if self.with_distributed_adam:

# Initialize params so that main grads are available
# Note: Consolidate grads without overlap
self._optimizer.init_params(
p for p in self.parameters() if getattr(p, '_disable_overlap_grad_sync', False)
)
self._optimizer.init_params(self.parameters())

if self._scheduler is None:
return self._optimizer
else:
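On the two init_params calls in configure_optimizers above: the first call passes only the params flagged with _disable_overlap_grad_sync, the second passes everything, so the non-overlapped params end up initialized first; passing the full list again implies the optimizer is expected to skip params it has already set up. A toy partition showing the ordering only (not Apex API):

import torch

params = [torch.nn.Parameter(torch.zeros(2)) for _ in range(4)]
params[1]._disable_overlap_grad_sync = True          # stand-in for an embedding weight

first_pass = [p for p in params if getattr(p, '_disable_overlap_grad_sync', False)]
second_pass = list(params)                           # full list, duplicates included
assert first_pass[0] is params[1] and len(second_pass) == 4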
98 changes: 57 additions & 41 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union
import itertools
from typing import Any, List, Optional, Union

import numpy as np
import torch
@@ -82,16 +83,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self._validate_trainer()

self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)
self.with_distributed_adam = cfg.optim.get('name') == 'distributed_fused_adam'

if not self.megatron_amp_o2 and self.cfg.get('virtual_pipeline_model_parallel_size', None):
raise ValueError('Virtual pipeline model parallel is only supported when using megatron_amp_O2')

if self.with_distributed_adam and not self.megatron_amp_o2:
raise ValueError(
"Distributed optimizers require O2. Please set megatron_amp_O2 to True in the model config."
)

# build_model returns a list of modules which are used for interleaved pipeline parallelism
self.model = build_model(
model_provider_func=self.model_provider_func,
@@ -186,15 +181,40 @@ def setup_optimizer_param_groups(self):
else:
self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model)

def setup_optimization(
self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
):
optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
def configure_optimizers(self):

if self.with_distributed_adam:
optim_kwargs['process_group'] = parallel_state.get_data_parallel_group()
optim_kwargs['param_sync_dtype'] = self.autocast_dtype
optim_kwargs['contiguous_grad_buffer'] = True
return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)

# Disable overlapped grad sync for embedding grad when
# pipeline parallelism is enabled
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[0] # only the first virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
param = module.word_embeddings_weight()
param._disable_greedy_grad_copy = not self.megatron_amp_o2
param._disable_overlap_grad_sync = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[-1] # only the last virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
param = module.word_embeddings_weight()
param._disable_greedy_grad_copy = not self.megatron_amp_o2
param._disable_overlap_grad_sync = True

# Disable overlapped grad sync for layer norm grads when
# sequence parallelism is enabled
for param in self.parameters():
if getattr(param, 'sequence_parallel_enabled', False):
param._disable_greedy_grad_copy = not self.megatron_amp_o2
param._disable_overlap_grad_sync = True

return super().configure_optimizers()
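The new configure_optimizers logic above tags individual parameters with plain Python attributes (_disable_greedy_grad_copy, _disable_overlap_grad_sync) that the grad-copy and grad-sync paths read later; shared token embeddings and sequence-parallel layer-norm weights are excluded from overlapped sync so their reductions stay correct. A toy version of the tagging, with a Linear layer standing in for a layer-norm weight:

import torch

megatron_amp_o2 = True
layer = torch.nn.Linear(4, 4)
layer.weight.sequence_parallel_enabled = True        # stand-in flag set by the model code

for param in layer.parameters():
    if getattr(param, 'sequence_parallel_enabled', False):
        param._disable_greedy_grad_copy = not megatron_amp_o2
        param._disable_overlap_grad_sync = True

assert layer.weight._disable_overlap_grad_sync is True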

def forward(self, tokens, text_position_ids, attention_mask, labels):
output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels)
@@ -236,16 +256,20 @@ def training_step(self, batch, batch_idx):

tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

# determine if we can use async grad all reduce
custom_sync_context_handler = None
if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False):
if self.with_distributed_adam:
# handle asynchronous grad reduction
if self.with_distributed_adam:
if self.megatron_amp_o2:
# copy grads to main grad
custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
else:
custom_sync_context_handler = self._optimizer.no_sync
# keep grad tensors around
custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
else:
# TODO: enable async grad all reduce for O1/autocast mixed precision training
custom_sync_context_handler = None
if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False):
custom_sync_context_handler = self._optimizer.no_sync
else:
# TODO: enable async grad all reduce for O1/autocast mixed precision training
custom_sync_context_handler = None
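The branch above picks the no-sync context manager handed to the forward/backward function: with distributed Adam, grads are either copied greedily into main grads (O2) or kept around (O1); without it, the existing O2 async allreduce path is used when sequence parallelism is off. A sketch that just labels the choice:

def sync_context_choice(with_distributed_adam: bool, megatron_amp_o2: bool, sequence_parallel: bool) -> str:
    """Label the custom_sync_context_handler training_step would use (sketch only)."""
    if with_distributed_adam:
        if megatron_amp_o2:
            return 'optimizer.no_sync(greedy_grad_copy=True)'    # copy grads to main grads
        return 'optimizer.no_sync(greedy_grad_copy=False)'       # keep grad tensors around
    if megatron_amp_o2 and not sequence_parallel:
        return 'optimizer.no_sync'                               # async grad allreduce via O2 wrapper
    return 'none'                                                # O1/autocast: not yet supported

assert sync_context_choice(True, False, False) == 'optimizer.no_sync(greedy_grad_copy=False)'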

# run forward and backwards passes for an entire global batch
# we do this inside training_step to support pipeline parallelism
@@ -277,8 +301,11 @@
self.allreduce_sequence_parallel_gradients()

if self.with_distributed_adam:
# gradients are reduced internally in distributed optimizer
pass
# launch grad reductions
# Note: grads in first pipeline stage have already been
# reduced
if not parallel_state.is_pipeline_first_stage():
self.reduce_overlap_gradients()
elif self.megatron_amp_o2:
# when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
@@ -753,23 +780,6 @@ def setup_test_data(self, cfg):
)
self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples)

def configure_optimizers(self):
retval = super().configure_optimizers()

if self.with_distributed_adam:

# Initialize params in reverse order
# Note: Estimate order in which grads are generated in
# backward pass
self._optimizer.init_params(reversed(list(self.parameters())))

# Overlapped communication interferes with grad reductions
# for pipeline parallelism and sequence parallelism
if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
self._optimizer.overlap_grad_sync = False

return retval

def generate(
self,
inputs: Union[List[str], torch.Tensor, List[dict]],
@@ -878,3 +888,9 @@ def on_load_checkpoint(self, checkpoint) -> None:
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(0)

def parameters(self):
if isinstance(self.model, list):
return itertools.chain.from_iterable(module.parameters() for module in self.model)
else:
return self.model.parameters()
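The parameters override above matters for interleaved pipeline parallelism, where self.model is a list of module chunks; chaining keeps callers (optimizer setup, grad-sync hooks) seeing one flat iterator. A toy check with two Linear chunks standing in for virtual pipeline ranks:

import itertools
import torch

chunks = [torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)]   # stand-ins for model chunks
all_params = list(itertools.chain.from_iterable(m.parameters() for m in chunks))
assert len(all_params) == 4                                # weight + bias per chunk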