@@ -532,9 +532,11 @@ def configure_optimizers(self):
if self.with_distributed_adam:

# Initialize param buckets if explicitly provided
if hasattr(self, 'distributed_adam_buckets'):
if getattr(self, 'distributed_adam_buckets', None):
for bucket in self.distributed_adam_buckets:
self._optimizer.init_params_bucket(bucket)
self._optimizer.init_params_bucket(self.parameters())
if hasattr(self, 'distributed_adam_buckets'):
del self.distributed_adam_buckets

# Make sure all params are initialized so main grads are
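
The hunk above initializes any explicitly provided distributed Adam buckets first and then sweeps every remaining parameter with init_params_bucket(self.parameters()); this only works because the MegatronDistributedFusedAdam override later in this diff skips parameters that already have optimizer state. A minimal, self-contained toy sketch of that idempotent pattern (ToyBucketedOptimizer is hypothetical, not the apex/NeMo implementation):

# Toy sketch of the "explicit buckets first, catch-all second" ordering used above.
# ToyBucketedOptimizer is hypothetical; the real optimizer tracks state via
# self.state[param]["fragments"] rather than an id() set.
import torch


class ToyBucketedOptimizer:
    def __init__(self):
        self.buckets = []  # list of parameter lists
        self.seen = set()  # params that already belong to a bucket

    def init_params_bucket(self, params):
        # Skip params that were already bucketed, mirroring the
        # '"fragments" not in self.state[param]' check in MegatronDistributedFusedAdam.
        new_params = [p for p in params if id(p) not in self.seen]
        if not new_params:
            return
        self.seen.update(id(p) for p in new_params)
        self.buckets.append(new_params)


params = [torch.nn.Parameter(torch.zeros(4)) for _ in range(5)]
opt = ToyBucketedOptimizer()
opt.init_params_bucket(params[:2])  # explicitly provided bucket
opt.init_params_bucket(params)      # catch-all: only the remaining three form a new bucket
assert len(opt.buckets) == 2 and len(opt.buckets[1]) == 3
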
nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py (26 changes: 15 additions & 11 deletions)
@@ -231,11 +231,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
)
else:
self.model = build_model(
model_provider_func=self.model_provider_func,
wrap_with_ddp=False,
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
)
fp8_enabled = cfg.get('fp8', False)
fp8_recipe = None
if fp8_enabled and HAVE_TE:
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=0, interval=1, fp8_format=transformer_engine.common.recipe.Format.E4M3
)
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
self.model = build_model(
model_provider_func=self.model_provider_func,
wrap_with_ddp=False,
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
)

# if we're not using interleaved, then self.model is a module.
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None:
@@ -295,8 +302,11 @@ def get_inference_config(self):
def model_provider_func(self, pre_process, post_process):
"""Model depends on pipeline paralellism."""
if self.mcore_gpt:
from megatron.core.models.gpt.gpt_decoder_spec import get_gpt_decoder_spec

model = MCoreGPTModel(
config=self.transformer_config,
spec=get_gpt_decoder_spec(),
vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
max_sequence_length=self.cfg.get('encoder_seq_length', 512),
pre_process=pre_process,
@@ -458,12 +468,6 @@ def configure_optimizers(self):
[p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
)
buckets.reverse()
used_params = set()
for bucket in buckets:
used_params.update(bucket)
remaining_params = [p for p in self.parameters() if p not in used_params]
if remaining_params:
buckets.append(remaining_params)
self.distributed_adam_buckets = buckets

return super().configure_optimizers()
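
The __init__ hunk in this file wraps build_model in Transformer Engine's fp8_autocast when cfg.fp8 is set. For context, a minimal standalone sketch of that Transformer Engine API, using an arbitrary te.Linear layer and shapes that are not part of the PR; it assumes FP8-capable hardware:

# Minimal Transformer Engine FP8 usage sketch; layer and shapes are arbitrary, not from this PR.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Same recipe settings as the hunk above: E4M3 format, margin=0, interval=1.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

layer = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(16, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)  # GEMMs execute in FP8 with delayed scaling
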
nemo/core/optim/distributed_adam.py (291 changes: 289 additions & 2 deletions)
@@ -14,10 +14,14 @@

import collections
import itertools
from typing import Callable, Iterable, Optional, Union
from typing import Callable, Dict, Iterable, Optional, Union

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam, _disable_pre_forward_hook
from apex.contrib.optimizers.distributed_fused_adam import (
DistributedFusedAdam,
_disable_pre_forward_hook,
_multi_tensor_copy,
)
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
from megatron.core.dist_checkpointing.mapping import ShardedTensor
@@ -27,6 +31,17 @@
optim_state_to_sharding_state,
)

# Check if Transformer Engine has FP8 tensor class
HAVE_TE_FP8TENSOR = False
try:
from transformer_engine.pytorch import Float8Tensor
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8

from transformer_engine.pytorch.fp8 import get_fp8_te_dtype

HAVE_TE_FP8TENSOR = True
except (ImportError, ModuleNotFoundError):
pass


def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
if isinstance(dtype, torch.dtype):
@@ -60,6 +75,10 @@ def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
return dtype


def _is_fp8_tensor(tensor: torch.Tensor) -> bool:
return HAVE_TE_FP8TENSOR and isinstance(tensor, Float8Tensor)


class MegatronDistributedFusedAdam(DistributedFusedAdam):
"""Wrapper class that supports NeMo-Megatron optimizations

@@ -114,6 +133,10 @@ def __init__(
fp32_params, grad_sync_dtype=torch.float32,
)

def _broadcast_params(self) -> None:
# Assume params have already been synchronized
pass

def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int,) -> Callable:
def hook(*unused):
if getattr(param, '_pre_forward_hook_is_enabled', False):
@@ -137,6 +160,105 @@ def hook(*unused):

return hook

def init_params(
self,
params: Optional[Iterable[torch.nn.Parameter]] = None,
param_sync_dtype: Optional[torch.dtype] = None,
**kwargs,
) -> None:
"""Initialize optimizer state for parameters

Initializes FP8 and non-FP8 params separately.

"""

# Default cases
if params is None:
params = self.parameters()
elif isinstance(params, torch.Tensor):
params = [params]

# Ignore parameters that have already been initialized
params = [param for param in params if "fragments" not in self.state[param]]
if not params:
return

# Initialize FP8 and non-FP8 tensors separately
if any(_is_fp8_tensor(param) for param in params):
super().init_params(
filter(_is_fp8_tensor, params), param_sync_dtype=torch.uint8, **kwargs,
)
super().init_params(
params, param_sync_dtype=param_sync_dtype, **kwargs,
)

def init_params_bucket(
self, params: Iterable[torch.nn.Parameter], param_sync_dtype: Optional[torch.dtype] = None, **kwargs,
) -> None:
"""Initialize optimizer state for parameters in one effective bucket

If any FP8 params are detected, all non-FP8 params are removed
from the bucket and their overlapped grad syncs are disabled.
This assumes that weight matrices are FP8 params and that
non-FP8 params are small (e.g. biases and layer norm params).

"""

# Ignore parameters that have already been initialized
if isinstance(params, torch.Tensor):
params = [params]
params = [param for param in params if "fragments" not in self.state[param]]
if not params:
return

# Ignore non-FP8 params if there are any FP8 params
if any(_is_fp8_tensor(param) for param in params):
for param in params:
if not _is_fp8_tensor(param):
param._disable_overlap_grad_sync = True
params = filter(_is_fp8_tensor, params)
param_sync_dtype = torch.uint8

# Initialize parameter buckets
super().init_params_bucket(
params, param_sync_dtype=param_sync_dtype, **kwargs,
)

def _init_param_state(
self,
param: torch.nn.Parameter,
param_group_id: int,
param_id: int,
param_sync_dtype: Optional[torch.dtype] = None,
**kwargs,
) -> None:
"""Initialize optimizer state for a parameter

Initializing the master weights requires slicing a flattened
view of the param. FP8 tensors do not handle these operations
gracefully, so we hack around it by explicitly casting to
FP32.

"""

# Initialize non-FP8 params as usual
if not _is_fp8_tensor(param):
super()._init_param_state(
param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs,
)

# Return immediately if already initialized
if "fragments" in self.state[param]:
return

# Initialize with FP32 copy of param
fp32_param = param.float()
super()._init_param_state(
fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs,
)
self.state[param].update(self.state[fp32_param])
del self.state[fp32_param]

def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None:
def is_grad_copy_enabled(param: torch.nn.Parameter) -> bool:
return not getattr(param, '_disable_greedy_grad_copy', False) and not getattr(
@@ -183,6 +305,171 @@ def grad_norm(
# Use cached grad norm
return super().grad_norm()

@torch.no_grad()
def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.ParameterFragment],) -> None:
"""Update parameter fragments with values from parameter buckets

For FP8 params, values are copied directly into the FP8 data
buffer.

"""

# Figure out corresponding positions in param buckets and params
buffers_in = []
buffers_out = []
fragments = list(fragments)
for fragment in fragments:

# Check if fragment needs to be updated
bucket_id = fragment.bucket_id
bucket_start, bucket_end = fragment.bucket_range
param_start, param_end = fragment.param_range
if param_end <= param_start or bucket_id not in self._params_buckets:

continue

# Corresponding positions in bucket and param
state_bucket = self.state["buckets"][bucket_id]
param_bucket = self._params_buckets[bucket_id]
param = self.parameter(fragment)
buffer_in = param_bucket.params_bucket[bucket_start:bucket_end]
if _is_fp8_tensor(param):
# Copy into FP8 param's data buffer
assert (
param_bucket.params_bucket.dtype == torch.uint8
), "Expected FP8 params to perform param sync in UINT8"
buffer_out = param._data.view(-1)[param_start:param_end]
buffers_in.append(buffer_in)
buffers_out.append(buffer_out)
elif torch.is_floating_point(buffer_in) and torch.is_floating_point(param):
# Cast between floating-point dtypes
buffer_out = param.detach().view(-1)[param_start:param_end]
buffers_in.append(buffer_in)
buffers_out.append(buffer_out)
else:
# Copy most significant bytes for non-floating-point
# dtypes
# Note: Assume dtypes are little-endian
buffer_out = param.detach().view(-1)[param_start:param_end]
in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8)
out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8)
copy_size = min(in_bytes.size(-1), out_bytes.size(-1))
buffers_in.append(in_bytes[..., -copy_size:])
buffers_out.append(out_bytes[..., -copy_size:])
if copy_size < out_bytes.size(-1):
out_bytes[..., :-copy_size].zero_()

# Copy data from parameter buckets to parameters
_multi_tensor_copy(
buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf,
)

# Precompute transposes
for fragment in fragments:
param = self.parameter(fragment)
if _is_fp8_tensor(param):
param._transpose = None
for fragment in fragments:
param = self.parameter(fragment)
if _is_fp8_tensor(param):
param.transpose()

@torch.no_grad()
def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket],) -> None:
"""Make sure local shards of parameters are in expected datatypes

For FP8 params, FP32 values are cast into FP8 using per-param
scaling factors and per-param amaxes are computed and reduced.

"""

# Just call base class function if there are no FP8 tensors
num_fp8_params = sum(1 for param in self.parameters() if _is_fp8_tensor(param))
if num_fp8_params == 0:
super()._check_params_shard_dtypes(params_buckets)
return

# FP8 scaling factors
amaxes = []
scales = []
scale_invs = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device,)
i = -1
for param in self.parameters():
if not _is_fp8_tensor(param):
continue
i += 1
fp8_meta = param.fp8_meta_view["scaling_fwd"]
fp8_meta_index = param.gemm_index
amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1))
scales.append(fp8_meta.scale[fp8_meta_index].view(1))
param._scale_inv_cache = scale_invs[i]

# Update cached scale-inverses
scale_inv_views = [scale_invs[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
scales, scale_inv_views, dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.reciprocal(scale_invs, out=scale_invs)

# Cast local data to FP8
fp8_params_shards = dict()
for param in self.parameters():
if not _is_fp8_tensor(param):
continue

# FP8 metadata
fp8_meta = param.fp8_meta_view["scaling_fwd"]
fp8_meta_index = param.gemm_index
fp8_dtype = get_fp8_te_dtype(param.fp8_meta_view["recipe"], fprop_tensor=True,)

# Iterate through fragments with local data
for fragment in self.state[param]["fragments"]:
if not fragment.in_local_shard:
continue
shard_start, shard_end = fragment.shard_range
if shard_end <= shard_start:
continue
shard_range = slice(shard_start, shard_end)

# Get bucket containing fragment
bucket_id = fragment.bucket_id
if bucket_id not in params_buckets:
continue
state_bucket = self.state["buckets"][bucket_id]
param_bucket = params_buckets[bucket_id]
if state_bucket.param_sync_dtype != torch.uint8:
continue

# Allocate FP8 buffer if needed
if bucket_id not in fp8_params_shards:
fp8_params_shards[bucket_id] = torch.empty_like(param_bucket.params_shard, dtype=torch.uint8,)

# FP8 cast and amax
fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1)
fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1)
cast_to_fp8(
fp32_fragment, fp8_meta, fp8_meta_index, fp8_dtype, out=fp8_fragment,
)

# Update param shards with FP8 buffers
for bucket_id, params_shard in fp8_params_shards.items():
params_buckets[bucket_id].params_shard = params_shard

# Reduce amaxes
packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device,)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.distributed.all_reduce(
packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group,
)
_multi_tensor_copy(
packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf,
)

# Handle any remaining dtype conversions
super()._check_params_shard_dtypes(params_buckets)

def sharded_state_dict(self, model_sharded_state_dict):
optimizer_state_dict = self.state_dict()

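
The _check_params_shard_dtypes override above keeps per-parameter amaxes (max-reduced across the distributed process group) and cached scale inverses in sync while casting local shards with cast_to_fp8. A plain-PyTorch sketch of that bookkeeping under the E4M3 recipe used in this PR; the scale formula shown is Transformer Engine's default delayed-scaling update and is included only for context, since the actual scale update and FP8 cast stay inside Transformer Engine:

# Plain-PyTorch sketch of the delayed-scaling bookkeeping; not the TE kernels.
import torch
import torch.distributed as dist

E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3
MARGIN = 0        # matches the DelayedScaling recipe used in this PR

# One amax per FP8 parameter; the real code reads these from fp8_meta amax histories.
amaxes = torch.tensor([0.7, 2.3, 11.0])

# Every rank must agree on the scaling factors, hence the MAX all-reduce in the code above.
if dist.is_available() and dist.is_initialized():
    dist.all_reduce(amaxes, op=dist.ReduceOp.MAX)

# Transformer Engine's default scale update (shown for context only).
scales = (E4M3_MAX / amaxes) / (2 ** MARGIN)
scale_invs = torch.reciprocal(scales)  # cached per param, analogous to _scale_inv_cache
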