4 changes: 2 additions & 2 deletions Dockerfile
@@ -69,10 +69,10 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
     git checkout 27cbe46714a50c43ed290f1b1472db8d2780c55c && \
     pip install .
 
-# Apex bugfix for PyTorch 23.11 container: https://github.com/NVIDIA/apex/pull/1760
+# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771
 RUN git clone https://github.com/NVIDIA/apex.git && \
     cd apex && \
-    git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 && \
+    git checkout b496d85fb88a801d8e680872a12822de310951fd && \
     pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
 
 # Transformer Engine 1.2.0
2 changes: 1 addition & 1 deletion README.rst
@@ -326,7 +326,7 @@ To install Apex, run
 
     git clone https://github.com/NVIDIA/apex.git
     cd apex
-    git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227
+    git checkout b496d85fb88a801d8e680872a12822de310951fd
     pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
 
 It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies.
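Not part of this diff, but a quick way to confirm that the pinned Apex commit actually built the optional extensions requested by those flags: importing the distributed fused Adam and fast layer norm contrib modules fails if the corresponding C++/CUDA extensions did not compile. A minimal check, assuming Apex was installed as shown above:

# Smoke test after building Apex with --distributed_adam and --fast_layer_norm.
# Import errors here usually mean the C++/CUDA extensions were not compiled.
from apex.contrib.layer_norm import FastLayerNorm
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

print(FastLayerNorm, DistributedFusedAdam)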
@@ -784,7 +784,7 @@ def configure_optimizers(self):
         if self.with_distributed_adam:
 
             # Initialize param buckets if explicitly provided
-            if getattr(self, 'distributed_adam_buckets', None):
+            if getattr(self, 'distributed_adam_buckets', None) is not None:
                 for bucket in self.distributed_adam_buckets:
                     self._optimizer.init_params_bucket(bucket)
                 self._optimizer.init_params_bucket(self.parameters())
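The switch from a truthiness check to `is not None` only matters when `distributed_adam_buckets` is explicitly set to an empty list: an empty list is falsy, so the old check skipped the whole block, including the final `init_params_bucket(self.parameters())` call. A minimal sketch of the difference; the `_FakeOptimizer` and `_Model` classes are hypothetical stand-ins, not NeMo or Apex APIs.

class _FakeOptimizer:
    def __init__(self):
        self.bucket_calls = []

    def init_params_bucket(self, params):
        self.bucket_calls.append(list(params))


class _Model:
    def __init__(self, buckets=None):
        self._optimizer = _FakeOptimizer()
        if buckets is not None:
            self.distributed_adam_buckets = buckets

    def parameters(self):
        return iter(())  # no real params needed for this illustration

    def configure(self, strict_none_check):
        provided = getattr(self, 'distributed_adam_buckets', None)
        # Old check: `if provided:` -- an explicitly provided empty list is falsy,
        # so even the catch-all init_params_bucket(self.parameters()) call is skipped.
        # New check: `if provided is not None:` -- the catch-all call still runs.
        take_branch = (provided is not None) if strict_none_check else bool(provided)
        if take_branch:
            for bucket in provided:
                self._optimizer.init_params_bucket(bucket)
            self._optimizer.init_params_bucket(self.parameters())
        return len(self._optimizer.bucket_calls)


print(_Model(buckets=[]).configure(strict_none_check=False))  # 0 buckets initialized (old behavior)
print(_Model(buckets=[]).configure(strict_none_check=True))   # 1 bucket initialized (new behavior)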
40 changes: 17 additions & 23 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -70,6 +70,7 @@
 from nemo.core.classes.common import PretrainedModelInfo
 from nemo.core.neural_types import ChannelType, NeuralType
 from nemo.utils import logging
+from nemo.utils.te_utils import is_float8tensor
 
 try:
     import apex.transformer.pipeline_parallel.utils
Expand Down Expand Up @@ -483,8 +484,18 @@ def configure_optimizers(self):
param._disable_overlap_grad_sync = True

# Initialize parameter buckets for overlapped grad and param syncs
# Note: Params with disabled overlapping are put in the
# last param bucket
# Note: Params with disabled overlapping and params in the
# first layer are put together in a bucket. If FP8 tensors
# are detected, those are also put in the first layer's
# bucket.
def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]:
bucket = [
param for param in module.parameters() if not getattr(param, '_disable_overlap_grad_sync', False)
]
if any(is_float8tensor(param) for param in bucket):
bucket = list(filter(is_float8tensor, bucket))
return bucket

buckets = []
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None:
# Initialize a bucket for each virtual pipeline stage
@@ -493,35 +504,18 @@ def configure_optimizers(self):
                         module = module.module
                     stage_bucket = []
                     layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
-                    for layer in layers:
-                        stage_bucket.extend(
-                            p
-                            for p in layer.parameters()
-                            if not getattr(p, '_disable_overlap_grad_sync', False) and p.requires_grad
-                        )
-                    buckets.append(stage_bucket)
+                    buckets.extend(make_parameter_bucket(layer) for layer in layers)
             else:
                 # Initialize a bucket for each Transformer layer
                 modules = self.model if isinstance(self.model, list) else [self.model]
                 for module in modules:
                     if isinstance(module, (Float16Module, MCoreFloat16Module)):
                         module = module.module
                     layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
-                    for layer in layers:
-                        buckets.append(
-                            [
-                                p
-                                for p in layer.parameters()
-                                if not getattr(p, '_disable_overlap_grad_sync', False) and p.requires_grad
-                            ]
-                        )
+                    buckets.extend(make_parameter_bucket(layer) for layer in layers)
             buckets.reverse()
-            used_params = set()
-            for bucket in buckets:
-                used_params.update(bucket)
-            remaining_params = [p for p in self.parameters() if p not in used_params and p.requires_grad]
-            if remaining_params:
-                buckets.append(remaining_params)
+            used_params = set(itertools.chain.from_iterable(buckets))
+            buckets[-1].extend(p for p in self.parameters() if p not in used_params)
             self.distributed_adam_buckets = buckets
 
         return super().configure_optimizers()
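Putting the megatron_gpt_model.py changes together: `make_parameter_bucket` builds one bucket per Transformer layer, dropping parameters that have `_disable_overlap_grad_sync` set and, if any FP8 tensors are present, keeping only those; the bucket list is then reversed and every parameter not already covered is folded into the last bucket. The sketch below reproduces that grouping logic on plain `torch.nn` modules; `is_float8tensor_stub` and the `is_fp8` flag are stand-ins for `nemo.utils.te_utils.is_float8tensor`, since no Transformer Engine FP8 tensors are involved here.

import itertools
from typing import List

import torch


def is_float8tensor_stub(param: torch.nn.Parameter) -> bool:
    # Stand-in for nemo.utils.te_utils.is_float8tensor: here we just read a flag
    # set on the parameter for illustration purposes.
    return getattr(param, 'is_fp8', False)


def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]:
    # Same shape as the helper added in this PR: skip params with overlapped
    # grad sync disabled, and keep only FP8 params when any are present.
    bucket = [p for p in module.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
    if any(is_float8tensor_stub(p) for p in bucket):
        bucket = list(filter(is_float8tensor_stub, bucket))
    return bucket


# Toy "model": three layers, with overlapped grad sync disabled on one parameter.
layers = [torch.nn.Linear(4, 4) for _ in range(3)]
layers[0].bias._disable_overlap_grad_sync = True

buckets = [make_parameter_bucket(layer) for layer in layers]
buckets.reverse()  # later layers finish their backward pass first, so their buckets come first

# Fold every parameter not covered by a per-layer bucket into the last bucket,
# mirroring the used_params / buckets[-1].extend(...) logic in the diff.
model = torch.nn.Sequential(*layers)
extra = torch.nn.Parameter(torch.zeros(2))  # pretend this parameter lives outside the layers
all_params = list(model.parameters()) + [extra]
used_params = set(itertools.chain.from_iterable(buckets))
buckets[-1].extend(p for p in all_params if p not in used_params)

print([len(b) for b in buckets])  # [2, 2, 3]: the disabled bias and `extra` end up in the last bucket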