From b031db6e5b28809677d6d16807489f840e9bc3f1 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 11 Dec 2023 14:25:49 -0800 Subject: [PATCH 1/6] Add distopt support for FP8 params and BF16 optimizer state Signed-off-by: Tim Moon --- Dockerfile | 6 +- Jenkinsfile | 2 +- README.rst | 6 +- .../language_modeling/megatron_base_model.py | 44 +- .../language_modeling/megatron_gpt_model.py | 21 +- nemo/core/optim/distributed_adam.py | 418 ++++++++++++++++-- nemo/utils/__init__.py | 1 + nemo/utils/dtype.py | 53 +++ 8 files changed, 480 insertions(+), 71 deletions(-) create mode 100644 nemo/utils/dtype.py diff --git a/Dockerfile b/Dockerfile index 5d3311c7cdfd..e24f41159660 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,12 +53,12 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ # Distributed Adam support for multiple dtypes RUN git clone https://github.com/NVIDIA/apex.git && \ cd apex && \ - git checkout 52e18c894223800cb611682dce27d88050edf1de && \ - pip install install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ + git checkout a2f6683b10fb4c29ab57c9e3d16957db76a8a5ba && \ + pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ cd TransformerEngine && \ - git fetch origin 8eae4ce2b8fdfbbe525fc8bfecb0df5498cc9687 && \ + git fetch origin ff760a9d838ae4617600cccb22131d0359ce0296 && \ git checkout FETCH_HEAD && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . diff --git a/Jenkinsfile b/Jenkinsfile index 12fafac57a67..17785f62cc52 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -61,7 +61,7 @@ pipeline { steps { sh 'git clone https://github.com/NVIDIA/TransformerEngine.git && \ cd TransformerEngine && \ - git fetch origin e6676c53f26f6ef072943c909d136cf2a39c1d90 && \ + git fetch origin ff760a9d838ae4617600cccb22131d0359ce0296 && \ git checkout FETCH_HEAD && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install .' diff --git a/README.rst b/README.rst index fba4aaf04f09..374512e4ceac 100644 --- a/README.rst +++ b/README.rst @@ -295,8 +295,8 @@ To install Apex, run git clone https://github.com/NVIDIA/apex.git cd apex - git checkout 52e18c894223800cb611682dce27d88050edf1de - pip install install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ + git checkout a2f6683b10fb4c29ab57c9e3d16957db76a8a5ba + pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies. @@ -335,7 +335,7 @@ Transformer Engine requires PyTorch to be built with CUDA 11.8. Flash Attention ~~~~~~~~~~~~~~~~~~~~ -Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn `_. 
If you want to use Flash Attention with attention bias (introduced from position encoding, e.g. Alibi), please also install triton pinned version following the `implementation `_. +Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn `_. If you want to use Flash Attention with attention bias (introduced from position encoding, e.g. Alibi), please also install triton pinned version following the `implementation `_. .. code-block:: bash diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index ccdd2e8725db..915bb069172e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -39,7 +39,7 @@ from nemo.collections.nlp.parts import utils_funcs from nemo.collections.nlp.parts.nlp_overrides import NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, GradScaler from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler -from nemo.utils import AppState, logging +from nemo.utils import AppState, logging, str_to_dtype from nemo.utils.get_rank import is_global_rank_zero try: @@ -457,19 +457,39 @@ def setup_optimization( self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None, ): optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy() + + def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any: + """Get keyword argument from config""" + val = None + if val is None and optim_kwargs: + val = optim_kwargs.get(key, None) + if val is None and optim_config: + val = optim_config.get(key, None) + if val is None and self._cfg.optim: + val = self._cfg.optim.get(key, None) + if val is None: + val = default_value + return val + if self.with_distributed_adam: - # Allocate contiguous buffer to avoid extra copies - optim_kwargs['contiguous_grad_buffer'] = True + # Allocate contiguous grad buffer to avoid extra copies + optim_kwargs['contiguous_grad_buffer'] = get_config_arg('contiguous_grad_buffer', True) + if self.megatron_amp_O2 and not optim_kwargs['contiguous_grad_buffer']: + raise ValueError( + "Distributed Adam optimizer requires contiguous param buffer for O2. " + "Either enable contiguous_grad_buffer or disable megatron_amp_O2." 
+ ) - # Make sure optimizer state is in FP32 - optim_dtype = torch.float32 + # Optimizer dtype + optim_dtype = str_to_dtype(get_config_arg('dtype', torch.float32)) optim_kwargs['dtype'] = optim_dtype # Make sure embedding grad reductions are in FP32 - for name, param in self.named_parameters(): - if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name: - param._with_fp32_optimizer = True + if optim_dtype == torch.float32: + for name, param in self.named_parameters(): + if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name: + param._with_fp32_optimizer = True # Match param allgather with model dtype model_dtype = torch.float32 @@ -478,7 +498,9 @@ def setup_optimization( optim_kwargs['param_sync_dtype'] = model_dtype # Determine whether to store master params in optimizer - if optim_dtype == model_dtype: + if self.cfg.get('fp8_params', False): + optim_kwargs['store_params'] = True + elif optim_dtype == model_dtype: optim_kwargs['store_params'] = False elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16: optim_kwargs['store_params'] = False @@ -545,9 +567,11 @@ def configure_optimizers(self): if self.with_distributed_adam: # Initialize param buckets if explicitly provided - if hasattr(self, 'distributed_adam_buckets'): + if getattr(self, 'distributed_adam_buckets', None): for bucket in self.distributed_adam_buckets: self._optimizer.init_params_bucket(bucket) + self._optimizer.init_params_bucket(self.parameters()) + if hasattr(self, 'distributed_adam_buckets'): del self.distributed_adam_buckets # Make sure all params are initialized so main grads are diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index c2e39ea03a3e..224c356f722e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -16,6 +16,7 @@ import os import queue import warnings +from contextlib import nullcontext from dataclasses import fields from functools import partial from typing import Any, Dict, Iterator, List, Optional, Union @@ -234,11 +235,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), ) else: - self.model = build_model( - model_provider_func=self.model_provider_func, - wrap_with_ddp=False, - virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), - ) + build_model_context = nullcontext + if HAVE_TE and self.cfg.get('fp8', False) and self.cfg.get('fp8_params', False): + build_model_context = transformer_engine.pytorch.fp8_model_init + with build_model_context(): + self.model = build_model( + model_provider_func=self.model_provider_func, + wrap_with_ddp=False, + virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + ) # if we're not using interleaved, then self.model is a module. 
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None: @@ -472,12 +477,6 @@ def configure_optimizers(self): [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] ) buckets.reverse() - used_params = set() - for bucket in buckets: - used_params.update(bucket) - remaining_params = [p for p in self.parameters() if p not in used_params] - if remaining_params: - buckets.append(remaining_params) self.distributed_adam_buckets = buckets return super().configure_optimizers() diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index d7bc049c1808..dbe31f57b6b4 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -13,54 +13,51 @@ # limitations under the License. import collections -import itertools -from typing import Callable, Iterable, Optional, Union +from typing import Callable, Dict, Iterable, Optional, Union import torch -from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam, _disable_pre_forward_hook +from apex.contrib.optimizers.distributed_fused_adam import ( + DistributedFusedAdam, + _disable_pre_forward_hook, + _multi_tensor_copy, +) from megatron.core import parallel_state from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace from megatron.core.dist_checkpointing.mapping import ShardedTensor from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state +from nemo.utils import logging, str_to_dtype -def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: - if isinstance(dtype, torch.dtype): - return dtype - name = str(dtype).strip().lower() - if name.startswith("torch."): - name = name.replace("torch.", "", 1) - if name.startswith("fp"): - name = name.replace("fp", "float", 1) - dtype = dict( - float32=torch.float32, - float=torch.float32, - float64=torch.float64, - double=torch.float64, - float16=torch.float16, - half=torch.float16, - bfloat16=torch.bfloat16, - bf16=torch.bfloat16, - uint8=torch.uint8, - byte=torch.uint8, - int8=torch.int8, - char=torch.int8, - int16=torch.int16, - short=torch.int16, - int32=torch.int32, - int=torch.int32, - int64=torch.int64, - long=torch.int64, - bool=torch.bool, - )[name] - return dtype +# Check if Transformer Engine has FP8 tensor class +HAVE_TE_FP8TENSOR = False +try: + from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 + from transformer_engine.pytorch.float8_tensor import Float8Tensor + + HAVE_TE_FP8TENSOR = True +except (ImportError, ModuleNotFoundError): + # Float8Tensor not found + pass + + +def _is_fp8_tensor(tensor: torch.Tensor) -> bool: + return HAVE_TE_FP8TENSOR and isinstance(tensor, Float8Tensor) class MegatronDistributedFusedAdam(DistributedFusedAdam): - """Wrapper class that supports NeMo-Megatron optimizations + """Adam optimizer with ZeRO algorithm - When O2-style optimizations are enabled, gradients are accumulated - into the main_grad buffer instead of the grad buffer. + Child class of Apex DistributedFusedAdam, with optimizations for + NeMo-Megatron. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts + defining parameter groups. + disable_distributed_parameters (bool, optional): use standard + data-parallel communication instead of ZeRO. + (default: False) + **kwargs: keyword arguments to pass to Apex + DistributedFusedAdam. 
""" @@ -84,7 +81,7 @@ def __init__( # Make sure dtypes are in right type for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'): if keyword in kwargs: - kwargs[keyword] = _str_to_dtype(kwargs[keyword]) + kwargs[keyword] = str_to_dtype(kwargs[keyword]) # Make sure params are in consistent format (list of param group dicts) param_groups = list(params) @@ -100,17 +97,21 @@ def __init__( fp32_params = [] for param_group in param_groups: fp32_params.extend( - filter(lambda param: getattr(param, '_with_fp32_optimizer', False), param_group['params'],) + filter(lambda param: getattr(param, '_with_fp32_optimizer', False), param_group['params']) ) if fp32_params: - assert self.dtype == torch.float32, ( - 'Param requires FP32 state, ' f'but optimizer is initialized with {dtype}' - ) + assert ( + self.dtype == torch.float32 + ), f'Param requires FP32 state but optimizer is initialized with {self.dtype}' self.init_params_bucket( fp32_params, grad_sync_dtype=torch.float32, ) - def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int,) -> Callable: + def _broadcast_params(self) -> None: + # Assume params have already been synchronized + pass + + def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int) -> Callable: def hook(*unused): if getattr(param, '_pre_forward_hook_is_enabled', False): raise RuntimeError( @@ -133,7 +134,173 @@ def hook(*unused): return hook + def init_params( + self, + params: Optional[Iterable[torch.nn.Parameter]] = None, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for parameters + + Initializes FP8 and non-FP8 params separately. + + """ + + # Default cases + if params is None: + params = self.parameters() + elif isinstance(params, torch.Tensor): + params = [params] + + # Ignore parameters that have already been initialized + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Initialize FP8 and non-FP8 tensors separately + if any(_is_fp8_tensor(param) for param in params): + super().init_params( + filter(_is_fp8_tensor, params), param_sync_dtype=torch.uint8, **kwargs, + ) + super().init_params( + params, param_sync_dtype=param_sync_dtype, **kwargs, + ) + + def init_params_bucket( + self, params: Iterable[torch.nn.Parameter], param_sync_dtype: Optional[torch.dtype] = None, **kwargs, + ) -> None: + """Initialize optimizer state for parameters in one effective bucket + + If any FP8 params are detected, all non-FP8 params are removed + from the bucket and their overlapped grad syncs are disabled. + This assumes that weight matrices are FP8 params and that + non-FP8 params are small (e.g. biases and layer norm params). 
+ + """ + + # Ignore parameters that have already been initialized + if isinstance(params, torch.Tensor): + params = [params] + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Ignore non-FP8 params if there are any FP8 params + if any(_is_fp8_tensor(param) for param in params): + for param in params: + if not _is_fp8_tensor(param): + param._disable_overlap_grad_sync = True + params = filter(_is_fp8_tensor, params) + param_sync_dtype = torch.uint8 + + # Initialize parameter buckets + super().init_params_bucket( + params, param_sync_dtype=param_sync_dtype, **kwargs, + ) + + def _init_param_state( + self, + param: torch.nn.Parameter, + param_group_id: int, + param_id: int, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for a parameter + + Initializing the master weights requires slicing a flattened + view of the param. FP8 tensors do not handle these operations + gracefully, so we hack around it by explicitly casting to + FP32. + + """ + + # Initialize non-FP8 params as usual + if not _is_fp8_tensor(param): + super()._init_param_state( + param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs, + ) + + # Return immediately if already initialized + if "fragments" in self.state[param]: + return + + # Initialize with FP32 copy of param + fp32_param = param.float() + super()._init_param_state( + fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs, + ) + self.state[param].update(self.state[fp32_param]) + del self.state[fp32_param] + + @torch.no_grad() + def init_param_buffer(self) -> None: + """Allocate contiguous buffers for param buckets + + For FP8 params, the FP8 data buffer is made a view into a + contiguous buffer. 
+ + """ + + # Make sure all params are initialized + self.contiguous_param_buffer = True + self.init_params() + + # Construct param buffers + buffer_sizes = collections.defaultdict(lambda: 0) + for bucket in self.state["buckets"]: + dtypes = bucket.dtypes() + buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes],) + for dtypes, buffer_size in buffer_sizes.items(): + _, _, param_sync_dtype = dtypes + self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device,) + + # Figure out corresponding positions in params and param buffer + params = list(self.parameters()) + param_flat_views = [] + param_buffer_views = [] + for i, param in enumerate(params): + fragment = self.state[param]["fragments"][0] + bucket_id = fragment.bucket_id + bucket = self.state["buckets"][bucket_id] + param_size = param.numel() + bucket_start, _ = fragment.bucket_range + buffer_offset = bucket.contiguous_buffer_offset + buffer_start = buffer_offset + bucket_start + buffer_end = buffer_start + param_size + param_buffer = self._param_buffers[bucket.dtypes()] + param_buffer_view = param_buffer[buffer_start:buffer_end].detach() + if param_buffer_view.device != param.device: + raise RuntimeError( + "Attempted to change a parameter with device={param.device} " + f"into a buffer view with device={param_buffer_view.device}" + ) + if _is_fp8_tensor(param): + param_flat_views.append(param._data.detach().view(-1)) + else: + if param_buffer_view.dtype != param.dtype: + raise RuntimeError( + f"Attempted to change a parameter with dtype={param.dtype} " + f"into a buffer view with dtype={param_buffer_view.dtype}" + ) + param_flat_views.append(param.detach().view(-1)) + param_buffer_views.append(param_buffer_view) + + # Copy values into param buffer + _multi_tensor_copy( + param_flat_views, param_buffer_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Make all params a view into the param buffer + for param, buffer_view in zip(params, param_buffer_views): + if _is_fp8_tensor(param): + param._data.data = buffer_view.view(param.size()) + else: + param.data = buffer_view.view(param.size()) + def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None: + """Attempt to launch gradient synchronization""" + def is_grad_copy_enabled(param: torch.nn.Parameter) -> bool: return not getattr(param, '_disable_greedy_grad_copy', False) and not getattr( param, '_disable_overlap_grad_sync', False @@ -166,7 +333,7 @@ def grad_norm( if force or self._grad_norm is None: # Compute norm of local gradients for distributed optimizer - grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type,) + grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type) if self.redundant_size > 1: grad_norm_sq /= self.redundant_size @@ -179,6 +346,171 @@ def grad_norm( # Use cached grad norm return super().grad_norm() + @torch.no_grad() + def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.ParameterFragment]) -> None: + """Update parameter fragments with values from parameter buckets + + For FP8 params, values are copied directly into the FP8 data + buffer. 
+ + """ + + # Figure out corresponding positions in param buckets and params + buffers_in = [] + buffers_out = [] + fragments = list(fragments) + for fragment in fragments: + + # Check if fragment needs to be updated + bucket_id = fragment.bucket_id + bucket_start, bucket_end = fragment.bucket_range + param_start, param_end = fragment.param_range + if param_end <= param_start or bucket_id not in self._params_buckets: + continue + + # Corresponding positions in bucket and param + param_bucket = self._params_buckets[bucket_id] + param = self.parameter(fragment) + buffer_in = param_bucket.params_bucket[bucket_start:bucket_end] + if _is_fp8_tensor(param): + # Copy into FP8 params's data buffer + assert ( + param_bucket.params_bucket.dtype == torch.uint8 + ), "Expected FP8 params to perform param sync in UINT8" + buffer_out = param._data.view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + elif torch.is_floating_point(buffer_in) and torch.is_floating_point(param): + # Cast between floating-point dtypes + buffer_out = param.detach().view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + buffer_out = param.detach().view(-1)[param_start:param_end] + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Copy data from parameter buckets to parameters + _multi_tensor_copy( + buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Update transpose caches + params = set(self.parameter(fragment) for fragment in fragments) + for param in params: + if _is_fp8_tensor(param): + param.transpose(update_cache=True) + + @torch.no_grad() + def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None: + """Make sure local shards of parameters are in expected datatypes + + For FP8 params, FP32 values are cast into FP8 using per-param + scaling factors and per-param amaxes are computed and reduced. 
+ + """ + + # Just call base class function if there are no FP8 tensors + num_fp8_params = sum(1 for param in self.parameters() if _is_fp8_tensor(param)) + if num_fp8_params == 0: + super()._check_params_shard_dtypes(params_buckets) + return + + # FP8 scaling factors + amaxes = [] + scales = [] + scale_invs = [] + i = -1 + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + i += 1 + fp8_meta = param._fp8_meta["scaling_fwd"] + fp8_meta_index = param._fp8_meta_index + amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) + scales.append(fp8_meta.scale[fp8_meta_index].view(1)) + scale_invs.append(param._scale_inv.view(1)) + + # Update cached scale-inverses + packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) + packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.reciprocal(packed_scales, out=packed_scales) + _multi_tensor_copy( + packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Cast local data to FP8 + fp8_params_shards = dict() + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + + # FP8 metadata + fp8_meta = param._fp8_meta["scaling_fwd"] + fp8_meta_index = param._fp8_meta_index + fp8_dtype = param._fp8_dtype + + # Iterate through fragments with local data + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) + + # Get bucket containing fragment + bucket_id = fragment.bucket_id + if bucket_id not in params_buckets: + continue + state_bucket = self.state["buckets"][bucket_id] + param_bucket = params_buckets[bucket_id] + if state_bucket.param_sync_dtype != torch.uint8: + continue + + # Allocate FP8 buffer if needed + if bucket_id not in fp8_params_shards: + fp8_params_shards[bucket_id] = torch.empty_like(param_bucket.params_shard, dtype=torch.uint8) + + # FP8 cast and amax + fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1) + fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1) + cast_to_fp8( + fp32_fragment, fp8_meta, fp8_meta_index, fp8_dtype, out=fp8_fragment, + ) + + # Update param shards with FP8 buffers + for bucket_id, params_shard in fp8_params_shards.items(): + params_buckets[bucket_id].params_shard = params_shard + + # Reduce amaxes + # Note: Assume each param has a separate amax + packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.distributed.all_reduce( + packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, + ) + _multi_tensor_copy( + packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Handle any remaining dtype conversions + super()._check_params_shard_dtypes(params_buckets) + def sharded_state_dict(self, model_sharded_state_dict): optimizer_state_dict = self.state_dict() diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py index 2c424f72e411..ebf892927723 100644 --- a/nemo/utils/__init__.py +++ b/nemo/utils/__init__.py @@ -22,6 +22,7 @@ cast_all, cast_tensor, ) +from nemo.utils.dtype import str_to_dtype from 
nemo.utils.nemo_logging import Logger as _Logger from nemo.utils.nemo_logging import LogMode as logging_mode diff --git a/nemo/utils/dtype.py b/nemo/utils/dtype.py new file mode 100644 index 000000000000..b17d2039fdb2 --- /dev/null +++ b/nemo/utils/dtype.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Union + +import torch + +_str_to_dtype: Dict[str, torch.dtype] = dict( + float32=torch.float32, + float=torch.float32, + float64=torch.float64, + double=torch.float64, + float16=torch.float16, + half=torch.float16, + bfloat16=torch.bfloat16, + bf16=torch.bfloat16, + uint8=torch.uint8, + byte=torch.uint8, + int8=torch.int8, + char=torch.int8, + int16=torch.int16, + short=torch.int16, + int32=torch.int32, + int=torch.int32, + int64=torch.int64, + long=torch.int64, + bool=torch.bool, +) + + +def str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: + """Convert a data type name to a PyTorch data type""" + if isinstance(dtype, torch.dtype): + return dtype + name = str(dtype).strip().lower() + if name.startswith("torch."): + name = name.replace("torch.", "", 1) + if name.startswith("fp"): + name = name.replace("fp", "float", 1) + if name not in _str_to_dtype: + raise ValueError(f"Unrecognized dtype ({name})") + return _str_to_dtype[name] From 12fd07ca885006cfc75b1d342af7b3c02d51d54e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 18 Dec 2023 20:43:06 +0000 Subject: [PATCH 2/6] Removed unused import Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index dbe31f57b6b4..10008ef37bcf 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -26,7 +26,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedTensor from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state -from nemo.utils import logging, str_to_dtype +from nemo.utils import str_to_dtype # Check if Transformer Engine has FP8 tensor class HAVE_TE_FP8TENSOR = False From dd22a18a3058d1e08f25188466a87c9fa2d19b82 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 18 Dec 2023 20:43:31 +0000 Subject: [PATCH 3/6] Update PyTorch container in Jenkins pipeline Signed-off-by: Tim Moon --- Jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index e7e607b68c02..a211662c6deb 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,7 +1,7 @@ pipeline { agent { docker { - image 'nvcr.io/nvidia/pytorch:23.09-py3' + image 'nvcr.io/nvidia/pytorch:23.11-py3' args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache:/root/.cache --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1' } } @@ -2983,7 +2983,7 @@ pipeline { sh "rm -rf 
examples/nlp/language_modeling/bert_pretrain_results" sh "rm -rf examples/nlp/language_modeling/bert_index_mappings" } - } + } stage('L2: Megatron RETRO Pretraining and Resume Training') { when { anyOf { From f40a38aa33cc50c802a3f39d319d6f22e59f8241 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 21 Dec 2023 01:24:22 +0000 Subject: [PATCH 4/6] Use custom container with Apex bugfixes See https://github.com/NVIDIA/apex/pull/1760. Signed-off-by: Tim Moon --- Jenkinsfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index a211662c6deb..102736c41ff3 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,7 +1,8 @@ pipeline { agent { docker { - image 'nvcr.io/nvidia/pytorch:23.11-py3' + // image 'nvcr.io/nvidia/pytorch:23.11-py3' + image 'gitlab-master.nvidia.com/tmoon/containers/nemo:23.11-apex-bugfix' args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache:/root/.cache --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1' } } From 2f8fc4efdb20f851c90e9a4219e52156801a4826 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 27 Dec 2023 22:45:12 +0000 Subject: [PATCH 5/6] Upgrade to PyTorch 23.11 container Signed-off-by: Tim Moon --- Dockerfile | 11 ++++++----- Jenkinsfile | 15 +++++++++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 3cbed0d7d3bb..befba9f0cb33 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.10-py3 +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.11-py3 # build an image that includes only the nemo dependencies, ensures that dependencies # are included first for optimal caching, and useful for building a development @@ -50,15 +50,16 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ git checkout e122536b7645edcb7ebf099b5c92a443f7dbf8e7 && \ pip install . -# Distributed Adam support for multiple dtypes -RUN git clone https://github.com/NVIDIA/apex.git && \ +# Apex bugfix for PyTorch 23.11 container: https://github.com/NVIDIA/apex/pull/1760 +RUN git clone https://github.com/timmoon10/apex.git && \ cd apex && \ - git checkout a2f6683b10fb4c29ab57c9e3d16957db76a8a5ba && \ + git checkout memory-efficient-layer-norm-bugfix && \ pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ +# Transformer Engine 1.2.0 RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ cd TransformerEngine && \ - git fetch origin ff760a9d838ae4617600cccb22131d0359ce0296 && \ + git fetch origin 4f9662fbe621671f5f905e772fc1138953af77f6 && \ git checkout FETCH_HEAD && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . 
diff --git a/Jenkinsfile b/Jenkinsfile index 5f424d79ca90..5b4f34cd3b05 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,7 +1,7 @@ pipeline { agent { docker { - image 'nvcr.io/nvidia/pytorch:23.10-py3' + image 'nvcr.io/nvidia/pytorch:23.11-py3' args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache:/root/.cache --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1' } } @@ -57,17 +57,28 @@ pipeline { } } + // Transformer Engine 1.2.0 stage('Transformer Engine installation') { steps { sh 'git clone https://github.com/NVIDIA/TransformerEngine.git && \ cd TransformerEngine && \ - git fetch origin cf6fc898286e4ad347ff88925c88663324e2b87d && \ + git fetch origin 4f9662fbe621671f5f905e772fc1138953af77f6 && \ git checkout FETCH_HEAD && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install .' } } + // Apex bugfix for PyTorch 23.11 container: https://github.com/NVIDIA/apex/pull/1760 + stage('Apex installation') { + steps { + sh 'git clone https://github.com/timmoon10/apex.git && \ + cd apex && \ + git checkout memory-efficient-layer-norm-bugfix && \ + cp -R apex /usr/local/lib/python3.10/dist-packages' + } + } + // pip package should be working with main, if not we can update the commit here // until the pip package is updated // stage('Megatron Core installation') { From b6890f4467da6266e2492a1063ebf118452a521e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 2 Jan 2024 10:37:57 -0800 Subject: [PATCH 6/6] Update Apex commit Signed-off-by: Tim Moon --- Dockerfile | 4 ++-- Jenkinsfile | 4 ++-- README.rst | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index befba9f0cb33..2823f120862b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,9 +51,9 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ pip install . # Apex bugfix for PyTorch 23.11 container: https://github.com/NVIDIA/apex/pull/1760 -RUN git clone https://github.com/timmoon10/apex.git && \ +RUN git clone https://github.com/NVIDIA/apex.git && \ cd apex && \ - git checkout memory-efficient-layer-norm-bugfix && \ + git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 && \ pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ # Transformer Engine 1.2.0 diff --git a/Jenkinsfile b/Jenkinsfile index 5b4f34cd3b05..8b88ba9fb1ac 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -72,9 +72,9 @@ pipeline { // Apex bugfix for PyTorch 23.11 container: https://github.com/NVIDIA/apex/pull/1760 stage('Apex installation') { steps { - sh 'git clone https://github.com/timmoon10/apex.git && \ + sh 'git clone https://github.com/NVIDIA/apex.git && \ cd apex && \ - git checkout memory-efficient-layer-norm-bugfix && \ + git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 && \ cp -R apex /usr/local/lib/python3.10/dist-packages' } } diff --git a/README.rst b/README.rst index bc68d67a3788..105e639e877d 100644 --- a/README.rst +++ b/README.rst @@ -48,8 +48,8 @@ Latest News :alt: H200-NeMo-performance :width: 600 -NeMo Framework has been updated with state-of-the-art features, -such as FSDP, Mixture-of-Experts, and RLHF with TensorRT-LLM to provide speedups up to 4.2x for Llama-2 pre-training on H200. 
+NeMo Framework has been updated with state-of-the-art features, +such as FSDP, Mixture-of-Experts, and RLHF with TensorRT-LLM to provide speedups up to 4.2x for Llama-2 pre-training on H200. **All of these features will be available in an upcoming release.** @@ -325,7 +325,7 @@ To install Apex, run git clone https://github.com/NVIDIA/apex.git cd apex - git checkout a2f6683b10fb4c29ab57c9e3d16957db76a8a5ba + git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies.
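As a quick illustration of the dtype handling introduced by this series, the following minimal sketch (not part of the patches themselves) shows how the new ``nemo.utils.str_to_dtype`` helper normalizes the ``optim.dtype`` setting that the distributed Adam optimizer now accepts, e.g. ``bf16`` for BF16 optimizer state. The optimizer config values are illustrative only and assume a NeMo build with these patches applied.

.. code-block:: python

    # Minimal usage sketch (illustrative, not part of the patch series): assumes a
    # NeMo build that already includes nemo/utils/dtype.py from these patches.
    import torch

    from nemo.utils import str_to_dtype

    # str_to_dtype accepts torch.dtype objects or common string aliases
    # ("bf16"/"bfloat16", "fp16"/"half", "torch.float32", ...) and raises
    # ValueError for unrecognized names.
    assert str_to_dtype(torch.bfloat16) == torch.bfloat16
    assert str_to_dtype("bf16") == torch.bfloat16
    assert str_to_dtype("fp16") == torch.float16
    assert str_to_dtype("torch.float32") == torch.float32

    # Hypothetical optimizer config mirroring how megatron_base_model.py forwards
    # the `dtype` entry to MegatronDistributedFusedAdam for BF16 optimizer state.
    optim_cfg = {"name": "distributed_fused_adam", "dtype": "bf16"}
    optim_kwargs = {"dtype": str_to_dtype(optim_cfg.get("dtype", torch.float32))}
    print(optim_kwargs)  # {'dtype': torch.bfloat16}

In the model config, this pairs with ``fp8: true`` and ``fp8_params: true``, which route model construction through ``transformer_engine.pytorch.fp8_model_init`` and switch distributed-optimizer parameter synchronization to UINT8 buffers, as implemented in the first patch of this series.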