diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index a7b1b9521e3c..ee24d849dcff 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -532,9 +532,11 @@ def configure_optimizers(self): if self.with_distributed_adam: # Initialize param buckets if explicitly provided - if hasattr(self, 'distributed_adam_buckets'): + if getattr(self, 'distributed_adam_buckets', None): for bucket in self.distributed_adam_buckets: self._optimizer.init_params_bucket(bucket) + self._optimizer.init_params_bucket(self.parameters()) + if hasattr(self, 'distributed_adam_buckets'): del self.distributed_adam_buckets # Make sure all params are initialized so main grads are diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 1d0e08abe6dd..dbd206e6b2a2 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -231,11 +231,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), ) else: - self.model = build_model( - model_provider_func=self.model_provider_func, - wrap_with_ddp=False, - virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), - ) + fp8_enabled = cfg.get('fp8', False) + fp8_recipe = None + if fp8_enabled and HAVE_TE: + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=0, interval=1, fp8_format=transformer_engine.common.recipe.Format.E4M3 + ) + with transformer_engine.pytorch.fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + self.model = build_model( + model_provider_func=self.model_provider_func, + wrap_with_ddp=False, + virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + ) # if we're not using interleaved, then self.model is a module. if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None: @@ -295,8 +302,11 @@ def get_inference_config(self): def model_provider_func(self, pre_process, post_process): """Model depends on pipeline paralellism.""" if self.mcore_gpt: + from megatron.core.models.gpt.gpt_decoder_spec import get_gpt_decoder_spec + model = MCoreGPTModel( config=self.transformer_config, + spec=get_gpt_decoder_spec(), vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), max_sequence_length=self.cfg.get('encoder_seq_length', 512), pre_process=pre_process, @@ -458,12 +468,6 @@ def configure_optimizers(self): [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] ) buckets.reverse() - used_params = set() - for bucket in buckets: - used_params.update(bucket) - remaining_params = [p for p in self.parameters() if p not in used_params] - if remaining_params: - buckets.append(remaining_params) self.distributed_adam_buckets = buckets return super().configure_optimizers() diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 62bba769f652..5c05010da7e3 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -14,10 +14,14 @@ import collections import itertools -from typing import Callable, Iterable, Optional, Union +from typing import Callable, Dict, Iterable, Optional, Union import torch -from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam, _disable_pre_forward_hook +from apex.contrib.optimizers.distributed_fused_adam import ( + DistributedFusedAdam, + _disable_pre_forward_hook, + _multi_tensor_copy, +) from megatron.core import parallel_state from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace from megatron.core.dist_checkpointing.mapping import ShardedTensor @@ -27,6 +31,17 @@ optim_state_to_sharding_state, ) +# Check if Transformer Engine has FP8 tensor class +HAVE_TE_FP8TENSOR = False +try: + from transformer_engine.pytorch import Float8Tensor + from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 + from transformer_engine.pytorch.fp8 import get_fp8_te_dtype + + HAVE_TE_FP8TENSOR = True +except (ImportError, ModuleNotFoundError): + pass + def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: if isinstance(dtype, torch.dtype): @@ -60,6 +75,10 @@ def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: return dtype +def _is_fp8_tensor(tensor: torch.Tensor) -> bool: + return HAVE_TE_FP8TENSOR and isinstance(tensor, Float8Tensor) + + class MegatronDistributedFusedAdam(DistributedFusedAdam): """Wrapper class that supports NeMo-Megatron optimizations @@ -114,6 +133,10 @@ def __init__( fp32_params, grad_sync_dtype=torch.float32, ) + def _broadcast_params(self) -> None: + # Assume params have already been synchronized + pass + def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int,) -> Callable: def hook(*unused): if getattr(param, '_pre_forward_hook_is_enabled', False): @@ -137,6 +160,105 @@ def hook(*unused): return hook + def init_params( + self, + params: Optional[Iterable[torch.nn.Parameter]] = None, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for parameters + + Initializes FP8 and non-FP8 params separately. + + """ + + # Default cases + if params is None: + params = self.parameters() + elif isinstance(params, torch.Tensor): + params = [params] + + # Ignore parameters that have already been initialized + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Initialize FP8 and non-FP8 tensors separately + if any(_is_fp8_tensor(param) for param in params): + super().init_params( + filter(_is_fp8_tensor, params), param_sync_dtype=torch.uint8, **kwargs, + ) + super().init_params( + params, param_sync_dtype=param_sync_dtype, **kwargs, + ) + + def init_params_bucket( + self, params: Iterable[torch.nn.Parameter], param_sync_dtype: Optional[torch.dtype] = None, **kwargs, + ) -> None: + """Initialize optimizer state for parameters in one effective bucket + + If any FP8 params are detected, all non-FP8 params are removed + from the bucket and their overlapped grad syncs are disabled. + This assumes that weight matrices are FP8 params and that + non-FP8 params are small (e.g. biases and layer norm params). + + """ + + # Ignore parameters that have already been initialized + if isinstance(params, torch.Tensor): + params = [params] + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Ignore non-FP8 params if there are any FP8 params + if any(_is_fp8_tensor(param) for param in params): + for param in params: + if not _is_fp8_tensor(param): + param._disable_overlap_grad_sync = True + params = filter(_is_fp8_tensor, params) + param_sync_dtype = torch.uint8 + + # Initialize parameter buckets + super().init_params_bucket( + params, param_sync_dtype=param_sync_dtype, **kwargs, + ) + + def _init_param_state( + self, + param: torch.nn.Parameter, + param_group_id: int, + param_id: int, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for a parameter + + Initializing the master weights requires slicing a flattened + view of the param. FP8 tensors do not handle these operations + gracefully, so we hack around it by explicitly casting to + FP32. + + """ + + # Initialize non-FP8 params as usual + if not _is_fp8_tensor(param): + super()._init_param_state( + param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs, + ) + + # Return immediately if already initialized + if "fragments" in self.state[param]: + return + + # Initialize with FP32 copy of param + fp32_param = param.float() + super()._init_param_state( + fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs, + ) + self.state[param].update(self.state[fp32_param]) + del self.state[fp32_param] + def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None: def is_grad_copy_enabled(param: torch.nn.Parameter) -> bool: return not getattr(param, '_disable_greedy_grad_copy', False) and not getattr( @@ -183,6 +305,171 @@ def grad_norm( # Use cached grad norm return super().grad_norm() + @torch.no_grad() + def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.ParameterFragment],) -> None: + """Update parameter fragments with values from parameter buckets + + For FP8 params, values are copied directly into the FP8 data + buffer. + + """ + + # Figure out corresponding positions in param buckets and params + buffers_in = [] + buffers_out = [] + fragments = list(fragments) + for fragment in fragments: + + # Check if fragment needs to be updated + bucket_id = fragment.bucket_id + bucket_start, bucket_end = fragment.bucket_range + param_start, param_end = fragment.param_range + if param_end <= param_start or bucket_id not in self._params_buckets: + continue + + # Corresponding positions in bucket and param + state_bucket = self.state["buckets"][bucket_id] + param_bucket = self._params_buckets[bucket_id] + param = self.parameter(fragment) + buffer_in = param_bucket.params_bucket[bucket_start:bucket_end] + if _is_fp8_tensor(param): + # Copy into FP8 params's data buffer + assert ( + param_bucket.params_bucket.dtype == torch.uint8 + ), "Expected FP8 params to perform param sync in UINT8" + buffer_out = param._data.view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + elif torch.is_floating_point(buffer_in) and torch.is_floating_point(param): + # Cast between floating-point dtypes + buffer_out = param.detach().view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + buffer_out = param.detach().view(-1)[param_start:param_end] + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Copy data from parameter buckets to parameters + _multi_tensor_copy( + buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Precompute transposes + for fragment in fragments: + param = self.parameter(fragment) + if _is_fp8_tensor(param): + param._transpose = None + for fragment in fragments: + param = self.parameter(fragment) + if _is_fp8_tensor(param): + param.transpose() + + @torch.no_grad() + def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket],) -> None: + """Make sure local shards of parameters are in expected datatypes + + For FP8 params, FP32 values are cast into FP8 using per-param + scaling factors and per-param amaxes are computed and reduced. + + """ + + # Just call base class function if there are no FP8 tensors + num_fp8_params = sum(1 for param in self.parameters() if _is_fp8_tensor(param)) + if num_fp8_params == 0: + super()._check_params_shard_dtypes(params_buckets) + return + + # FP8 scaling factors + amaxes = [] + scales = [] + scale_invs = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device,) + i = -1 + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + i += 1 + fp8_meta = param.fp8_meta_view["scaling_fwd"] + fp8_meta_index = param.gemm_index + amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) + scales.append(fp8_meta.scale[fp8_meta_index].view(1)) + param._scale_inv_cache = scale_invs[i] + + # Update cached scale-inverses + scale_inv_views = [scale_invs[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + scales, scale_inv_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.reciprocal(scale_invs, out=scale_invs) + + # Cast local data to FP8 + fp8_params_shards = dict() + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + + # FP8 metadata + fp8_meta = param.fp8_meta_view["scaling_fwd"] + fp8_meta_index = param.gemm_index + fp8_dtype = get_fp8_te_dtype(param.fp8_meta_view["recipe"], fprop_tensor=True,) + + # Iterate through fragments with local data + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) + + # Get bucket containing fragment + bucket_id = fragment.bucket_id + if bucket_id not in params_buckets: + continue + state_bucket = self.state["buckets"][bucket_id] + param_bucket = params_buckets[bucket_id] + if state_bucket.param_sync_dtype != torch.uint8: + continue + + # Allocate FP8 buffer if needed + if bucket_id not in fp8_params_shards: + fp8_params_shards[bucket_id] = torch.empty_like(param_bucket.params_shard, dtype=torch.uint8,) + + # FP8 cast and amax + fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1) + fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1) + cast_to_fp8( + fp32_fragment, fp8_meta, fp8_meta_index, fp8_dtype, out=fp8_fragment, + ) + + # Update param shards with FP8 buffers + for bucket_id, params_shard in fp8_params_shards.items(): + params_buckets[bucket_id].params_shard = params_shard + + # Reduce amaxes + packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device,) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.distributed.all_reduce( + packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, + ) + _multi_tensor_copy( + packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Handle any remaining dtype conversions + super()._check_params_shard_dtypes(params_buckets) + def sharded_state_dict(self, model_sharded_state_dict): optimizer_state_dict = self.state_dict() diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index 922412f1e8a6..149dff44620a 100644 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -162,7 +162,7 @@ class MainParamsOptimizerWrapper(torch.optim.Optimizer): Arguments: optimizer: base optimizer such as Adam or SGD. fp32_grad_accum: to enable the use of fp32 in gradient accumulation and allreduce. - contiguous_grad_bucket: to enable allocating the master gradients in the + contiguous_grad_bucket: to enable allocating the master gradients in the contiguous memory space to reduce memory fragmentation. async_grad_allreduce: enable asynchronous gradient allreduce that is executed along with the training step backprop.