diff --git a/Dockerfile b/Dockerfile
index 094c8fff408e..6a5c48bee4c4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -69,10 +69,10 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
     git checkout 27cbe46714a50c43ed290f1b1472db8d2780c55c && \
     pip install .
 
-# Apex bugfix for PyTorch 23.11 container: https://github.com/NVIDIA/apex/pull/1760
+# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771
 RUN git clone https://github.com/NVIDIA/apex.git && \
     cd apex && \
-    git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 && \
+    git checkout b496d85fb88a801d8e680872a12822de310951fd && \
     pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
 
 # Transformer Engine 1.2.0
diff --git a/README.rst b/README.rst
index 78396d80dc45..44e5df6b7488 100644
--- a/README.rst
+++ b/README.rst
@@ -326,7 +326,7 @@ To install Apex, run
 
     git clone https://github.com/NVIDIA/apex.git
     cd apex
-    git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227
+    git checkout b496d85fb88a801d8e680872a12822de310951fd
     pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
 
 It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies.
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index d7130788f9d8..56f5d146e964 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -784,7 +784,7 @@ def configure_optimizers(self):
         if self.with_distributed_adam:
 
             # Initialize param buckets if explicitly provided
-            if getattr(self, 'distributed_adam_buckets', None):
+            if getattr(self, 'distributed_adam_buckets', None) is not None:
                 for bucket in self.distributed_adam_buckets:
                     self._optimizer.init_params_bucket(bucket)
                 self._optimizer.init_params_bucket(self.parameters())
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index ecb5db742785..8158c88a2522 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -70,6 +70,7 @@
 from nemo.core.classes.common import PretrainedModelInfo
 from nemo.core.neural_types import ChannelType, NeuralType
 from nemo.utils import logging
+from nemo.utils.te_utils import is_float8tensor
 
 try:
     import apex.transformer.pipeline_parallel.utils
@@ -483,8 +484,18 @@ def configure_optimizers(self):
                     param._disable_overlap_grad_sync = True
 
             # Initialize parameter buckets for overlapped grad and param syncs
-            # Note: Params with disabled overlapping are put in the
-            # last param bucket
+            # Note: Params with disabled overlapping and params in the
+            # first layer are put together in a bucket. If FP8 tensors
+            # are detected, those are also put in the first layer's
+            # bucket.
+            def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]:
+                bucket = [
+                    param for param in module.parameters() if not getattr(param, '_disable_overlap_grad_sync', False)
+                ]
+                if any(is_float8tensor(param) for param in bucket):
+                    bucket = list(filter(is_float8tensor, bucket))
+                return bucket
+
             buckets = []
             if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None:
                 # Initialize a bucket for each virtual pipeline stage
@@ -493,13 +504,7 @@ def configure_optimizers(self):
                         module = module.module
                     stage_bucket = []
                     layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
-                    for layer in layers:
-                        stage_bucket.extend(
-                            p
-                            for p in layer.parameters()
-                            if not getattr(p, '_disable_overlap_grad_sync', False) and p.requires_grad
-                        )
-                    buckets.append(stage_bucket)
+                    buckets.extend(make_parameter_bucket(layer) for layer in layers)
             else:
                 # Initialize a bucket for each Transformer layer
                 modules = self.model if isinstance(self.model, list) else [self.model]
@@ -507,21 +512,10 @@ def configure_optimizers(self):
                     if isinstance(module, (Float16Module, MCoreFloat16Module)):
                         module = module.module
                     layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
-                    for layer in layers:
-                        buckets.append(
-                            [
-                                p
-                                for p in layer.parameters()
-                                if not getattr(p, '_disable_overlap_grad_sync', False) and p.requires_grad
-                            ]
-                        )
+                    buckets.extend(make_parameter_bucket(layer) for layer in layers)
             buckets.reverse()
-            used_params = set()
-            for bucket in buckets:
-                used_params.update(bucket)
-            remaining_params = [p for p in self.parameters() if p not in used_params and p.requires_grad]
-            if remaining_params:
-                buckets.append(remaining_params)
+            used_params = set(itertools.chain.from_iterable(buckets))
+            buckets[-1].extend(p for p in self.parameters() if p not in used_params)
             self.distributed_adam_buckets = buckets
 
         return super().configure_optimizers()
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index 24ffcc17e6c2..a2316dabb023 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import collections
+import itertools
 from typing import Callable, Dict, Iterable, Optional, Union
 
 import torch
@@ -25,23 +26,10 @@
 from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
 from megatron.core.dist_checkpointing.mapping import ShardedTensor
 from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state
+from transformer_engine.pytorch.cpp_extensions import cast_to_fp8
 
 from nemo.utils import str_to_dtype
-
-# Check if Transformer Engine has FP8 tensor class
-HAVE_TE_FP8TENSOR = False
-try:
-    from transformer_engine.pytorch.cpp_extensions import cast_to_fp8
-    from transformer_engine.pytorch.float8_tensor import Float8Tensor
-
-    HAVE_TE_FP8TENSOR = True
-except (ImportError, ModuleNotFoundError):
-    # Float8Tensor not found
-    pass
-
-
-def _is_fp8_tensor(tensor: torch.Tensor) -> bool:
-    return HAVE_TE_FP8TENSOR and isinstance(tensor, Float8Tensor)
+from nemo.utils.te_utils import is_float8tensor
 
 
 class MegatronDistributedFusedAdam(DistributedFusedAdam):
@@ -92,21 +80,6 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)
 
-        # Initialize weights that require FP32 grads
-        if self.dtype != torch.float32 or self.grad_sync_dtype != torch.float32:
-            fp32_params = []
-            for param_group in param_groups:
-                fp32_params.extend(
-                    filter(lambda param: getattr(param, '_with_fp32_optimizer', False), param_group['params'])
-                )
-            if fp32_params:
-                assert (
-                    self.dtype == torch.float32
-                ), f'Param requires FP32 state but optimizer is initialized with {self.dtype}'
-                self.init_params_bucket(
-                    fp32_params, grad_sync_dtype=torch.float32,
-                )
-
     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
         pass
@@ -158,25 +131,22 @@ def init_params(
             return
 
         # Initialize FP8 and non-FP8 tensors separately
-        if any(_is_fp8_tensor(param) for param in params):
+        if any(is_float8tensor(param) for param in params):
             super().init_params(
-                filter(_is_fp8_tensor, params), param_sync_dtype=torch.uint8, **kwargs,
+                filter(is_float8tensor, params), param_sync_dtype=torch.uint8, **kwargs,
             )
         super().init_params(
             params, param_sync_dtype=param_sync_dtype, **kwargs,
         )
 
     def init_params_bucket(
-        self, params: Iterable[torch.nn.Parameter], param_sync_dtype: Optional[torch.dtype] = None, **kwargs,
+        self,
+        params: Iterable[torch.nn.Parameter],
+        grad_sync_dtype: Optional[torch.dtype] = None,
+        param_sync_dtype: Optional[torch.dtype] = None,
+        **kwargs,
     ) -> None:
-        """Initialize optimizer state for parameters in one effective bucket
-
-        If any FP8 params are detected, all non-FP8 params are removed
-        from the bucket and their overlapped grad syncs are disabled.
-        This assumes that weight matrices are FP8 params and that
-        non-FP8 params are small (e.g. biases and layer norm params).
-
-        """
+        """Initialize optimizer state for parameters in one effective bucket"""
 
         # Ignore parameters that have already been initialized
         if isinstance(params, torch.Tensor):
@@ -185,18 +155,79 @@
         if not params:
             return
 
-        # Ignore non-FP8 params if there are any FP8 params
-        if any(_is_fp8_tensor(param) for param in params):
-            for param in params:
-                if not _is_fp8_tensor(param):
-                    param._disable_overlap_grad_sync = True
-            params = filter(_is_fp8_tensor, params)
-            param_sync_dtype = torch.uint8
+        # Initialize parameters with FP32 grads
+        fp32_params = []
+        remaining_params = []
+        for param in params:
+            if getattr(param, '_with_fp32_optimizer', False):
+                fp32_params.append(param)
+            else:
+                remaining_params.append(param)
+        params = remaining_params
+        start_bucket_id = len(self.state["buckets"])
+        super().init_params_bucket(
+            fp32_params, grad_sync_dtype=torch.float32, param_sync_dtype=param_sync_dtype, **kwargs,
+        )
+        end_bucket_id = len(self.state["buckets"])
+        fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
 
-        # Initialize parameter buckets
+        # Initialize FP8 parameters
+        fp8_params = []
+        remaining_params = []
+        for param in params:
+            if is_float8tensor(param):
+                fp8_params.append(param)
+            else:
+                remaining_params.append(param)
+        params = remaining_params
+        start_bucket_id = len(self.state["buckets"])
         super().init_params_bucket(
-            params, param_sync_dtype=param_sync_dtype, **kwargs,
+            fp8_params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=torch.uint8, **kwargs,
         )
+        end_bucket_id = len(self.state["buckets"])
+        fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
+
+        # Initialize remaining parameters as usual
+        normal_buckets = []
+        start_bucket_id = len(self.state["buckets"])
+        super().init_params_bucket(
+            params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, **kwargs,
+        )
+        end_bucket_id = len(self.state["buckets"])
+        normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
+
+        def add_param_to_bucket(param: torch.nn.Parameter, bucket: self.StateBucket,) -> None:
+            """Add trivial param fragment to bucket"""
+            param_fragments = self.state[param]["fragments"]
+            param_group_id = param_fragments[0].param_group_id
+            param_id = param_fragments[0].param_id
+            bucket_id = bucket.fragments[0].bucket_id
+            param_size = param.numel()
+            bucket_size = bucket.bucket_size
+            fragment = self.ParameterFragment(
+                param_group_id=param_group_id,
+                param_id=param_id,
+                bucket_id=bucket_id,
+                param_range=(param_size, param_size),
+                bucket_range=(bucket_size, bucket_size),
+                in_local_shard=False,
+                shard_range=None,
+                shard_bucket_range=None,
+                shard_param_range=None,
+            )
+            param_fragments.append(fragment)
+            bucket.fragments.append(fragment)
+
+        # Make sure all added buckets depend on provided params
+        for bucket in fp32_buckets:
+            for param in itertools.chain(fp8_params, params):
+                add_param_to_bucket(param, bucket)
+        for bucket in fp8_buckets:
+            for param in itertools.chain(fp32_params, params):
+                add_param_to_bucket(param, bucket)
+        for bucket in normal_buckets:
+            for param in itertools.chain(fp32_params, fp8_params):
+                add_param_to_bucket(param, bucket)
 
     def _init_param_state(
         self,
@@ -216,7 +247,7 @@
         """
 
         # Initialize non-FP8 params as usual
-        if not _is_fp8_tensor(param):
+        if not is_float8tensor(param):
             super()._init_param_state(
                 param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs,
             )
@@ -250,10 +281,10 @@ def init_param_buffer(self) -> None:
         buffer_sizes = collections.defaultdict(lambda: 0)
         for bucket in self.state["buckets"]:
             dtypes = bucket.dtypes()
-            buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes],)
+            buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes])
         for dtypes, buffer_size in buffer_sizes.items():
             _, _, param_sync_dtype = dtypes
-            self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device,)
+            self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device)
 
         # Figure out corresponding positions in params and param buffer
         params = list(self.parameters())
@@ -275,7 +306,7 @@ def init_param_buffer(self) -> None:
                     "Attempted to change a parameter with device={param.device} "
                     f"into a buffer view with device={param_buffer_view.device}"
                 )
-            if _is_fp8_tensor(param):
+            if is_float8tensor(param):
                 param_flat_views.append(param._data.detach().view(-1))
             else:
                 if param_buffer_view.dtype != param.dtype:
@@ -293,8 +324,8 @@ def init_param_buffer(self) -> None:
 
         # Make all params a view into the param buffer
         for param, buffer_view in zip(params, param_buffer_views):
-            if _is_fp8_tensor(param):
-                param._data.data = buffer_view.view(param.size())
+            if is_float8tensor(param):
+                param._data = buffer_view.view(param.size())
             else:
                 param.data = buffer_view.view(param.size())
 
@@ -372,7 +403,7 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet
             param_bucket = self._params_buckets[bucket_id]
             param = self.parameter(fragment)
             buffer_in = param_bucket.params_bucket[bucket_start:bucket_end]
-            if _is_fp8_tensor(param):
+            if is_float8tensor(param):
                 # Copy into FP8 params's data buffer
                 assert (
                     param_bucket.params_bucket.dtype == torch.uint8
@@ -406,8 +437,10 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet
         # Update transpose caches
         params = set(self.parameter(fragment) for fragment in fragments)
         for param in params:
-            if _is_fp8_tensor(param):
+            if is_float8tensor(param):
+                param._reset_caches()
                 param.transpose(update_cache=True)
+                param._lazy_transpose_cache = True
 
     @torch.no_grad()
     def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
@@ -419,94 +452,99 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         """
 
         # Just call base class function if there are no FP8 tensors
-        num_fp8_params = sum(1 for param in self.parameters() if _is_fp8_tensor(param))
+        num_fp8_params = sum(1 for param in self.parameters() if is_float8tensor(param))
         if num_fp8_params == 0:
             super()._check_params_shard_dtypes(params_buckets)
             return
 
-        # FP8 scaling factors
-        amaxes = []
-        scales = []
-        scale_invs = []
-        i = -1
-        for param in self.parameters():
-            if not _is_fp8_tensor(param):
-                continue
-            i += 1
-            fp8_meta = param._fp8_meta["scaling_fwd"]
-            fp8_meta_index = param._fp8_meta_index
-            amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1))
-            scales.append(fp8_meta.scale[fp8_meta_index].view(1))
-            scale_invs.append(param._scale_inv.view(1))
-
-        # Update cached scale-inverses
-        packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
-        packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)]
-        _multi_tensor_copy(
-            scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf,
-        )
-        torch.reciprocal(packed_scales, out=packed_scales)
-        _multi_tensor_copy(
-            packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf,
-        )
-
         # Cast local data to FP8
         fp8_params_shards = dict()
-        for param in self.parameters():
-            if not _is_fp8_tensor(param):
+        for bucket_id, param_bucket in params_buckets.items():
+            state_bucket = self.state["buckets"][bucket_id]
+            if state_bucket.param_sync_dtype != torch.uint8:
                 continue
-
-            # FP8 metadata
-            fp8_meta = param._fp8_meta["scaling_fwd"]
-            fp8_meta_index = param._fp8_meta_index
-            fp8_dtype = param._fp8_dtype
+            # Initialize FP8 buffer for param sync
+            params_shard = param_bucket.params_shard
+            if self.contiguous_param_buffer:
+                shard_size = state_bucket.shard_size
+                buffer_offset = state_bucket.contiguous_buffer_offset
+                buffer_start = buffer_offset + self.distributed_rank * shard_size
+                buffer_end = buffer_start + shard_size
+                param_buffer = self._param_buffers[state_bucket.dtypes()]
+                fp8_params_shard = param_buffer[buffer_start:buffer_end]
+            else:
+                fp8_params_shard = torch.empty_like(params_shard, dtype=torch.uint8)
+            param_bucket.params_shard = fp8_params_shard
 
-            # Iterate through fragments with local data
-            for fragment in self.state[param]["fragments"]:
+            # Cast param fragments to FP8
+            for fragment in self.state["buckets"][bucket_id].fragments:
+                param = self.parameter(fragment)
+                if not is_float8tensor(param):
+                    continue
                 if not fragment.in_local_shard:
                     continue
                 shard_start, shard_end = fragment.shard_range
                 if shard_end <= shard_start:
                     continue
                 shard_range = slice(shard_start, shard_end)
-
-                # Get bucket containing fragment
-                bucket_id = fragment.bucket_id
-                if bucket_id not in params_buckets:
-                    continue
-                state_bucket = self.state["buckets"][bucket_id]
-                param_bucket = params_buckets[bucket_id]
-                if state_bucket.param_sync_dtype != torch.uint8:
-                    continue
-
-                # Allocate FP8 buffer if needed
-                if bucket_id not in fp8_params_shards:
-                    fp8_params_shards[bucket_id] = torch.empty_like(param_bucket.params_shard, dtype=torch.uint8)
-
-                # FP8 cast and amax
-                fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1)
-                fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1)
                 cast_to_fp8(
-                    fp32_fragment, fp8_meta, fp8_meta_index, fp8_dtype, out=fp8_fragment,
+                    params_shard[shard_range].view(1, -1),
+                    param._fp8_meta["scaling_fwd"],
+                    param._fp8_meta_index,
+                    param._fp8_dtype,
+                    out=fp8_params_shard[shard_range].view(1, -1),
                 )
 
-        # Update param shards with FP8 buffers
-        for bucket_id, params_shard in fp8_params_shards.items():
-            params_buckets[bucket_id].params_shard = params_shard
+        # Update FP8 scaling factors when all buckets have been processed
+        if getattr(self, "_check_params_shard_dtypes_progress", None) is None:
+            self._check_params_shard_dtypes_progress = []
+        self._check_params_shard_dtypes_progress.extend(params_buckets.keys())
+        if len(self._check_params_shard_dtypes_progress) == len(self.state["buckets"]):
+            assert len(set(self._check_params_shard_dtypes_progress)) == len(self.state["buckets"])
+
+            # FP8 scaling factors
+            amaxes = []
+            scales = []
+            scale_invs = []
+            i = -1
+            for param in self.parameters():
+                if not is_float8tensor(param):
+                    continue
+                i += 1
+                fp8_meta = param._fp8_meta["scaling_fwd"]
+                fp8_meta_index = param._fp8_meta_index
+                amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1))
+                scales.append(fp8_meta.scale[fp8_meta_index].view(1))
+                scale_invs.append(param._scale_inv.view(1))
+
+            # Update cached scale-inverses
+            packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
+            packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)]
+            _multi_tensor_copy(
+                scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf,
+            )
+            torch.reciprocal(packed_scales, out=packed_scales)
+            _multi_tensor_copy(
+                packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf,
+            )
+
+            # Reduce amaxes
+            # Note: Assume each param has a separate amax
+            packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
+            packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
+            _multi_tensor_copy(
+                amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf,
+            )
+            torch.distributed.all_reduce(
+                packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group,
+            )
+            _multi_tensor_copy(
+                packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf,
+            )
 
-        # Reduce amaxes
-        # Note: Assume each param has a separate amax
-        packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
-        packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
-        _multi_tensor_copy(
-            amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf,
-        )
-        torch.distributed.all_reduce(
-            packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group,
-        )
-        _multi_tensor_copy(
-            packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf,
-        )
+            # Reset
+            self._check_params_shard_dtypes_progress = None
 
         # Handle any remaining dtype conversions
         super()._check_params_shard_dtypes(params_buckets)
diff --git a/nemo/utils/te_utils.py b/nemo/utils/te_utils.py
new file mode 100644
index 000000000000..8f073e211681
--- /dev/null
+++ b/nemo/utils/te_utils.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+# Check if Transformer Engine has Float8Tensor class
+HAVE_TE_FLOAT8TENSOR = False
+try:
+    from transformer_engine.pytorch.float8_tensor import Float8Tensor
+
+    HAVE_TE_FLOAT8TENSOR = True
+except (ImportError, ModuleNotFoundError):
+    # Float8Tensor not found
+    pass
+
+
+def is_float8tensor(tensor: torch.Tensor) -> bool:
+    """Check if a tensor is a Transformer Engine Float8Tensor"""
+    return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
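
Reviewer note (not part of the patch): the sketch below is illustrative only. It mirrors the three-way split that the updated MegatronDistributedFusedAdam.init_params_bucket performs before delegating to the base class: params flagged with `_with_fp32_optimizer` are bucketed with FP32 grad syncs, Transformer Engine Float8Tensor params are bucketed with uint8 param syncs, and everything else keeps the optimizer's default dtypes. The `is_float8tensor` helper and the `_with_fp32_optimizer` attribute come from the diff above; the standalone `partition_params` function is a hypothetical name used only for this example.

    # Illustrative sketch only, assuming the nemo.utils.te_utils module added by this patch.
    from typing import Iterable, List, Tuple

    import torch

    from nemo.utils.te_utils import is_float8tensor  # helper introduced in this diff


    def partition_params(
        params: Iterable[torch.nn.Parameter],
    ) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter], List[torch.nn.Parameter]]:
        """Split params into FP32-grad, FP8, and regular groups (hypothetical helper)."""
        fp32_params, fp8_params, normal_params = [], [], []
        for param in params:
            if getattr(param, '_with_fp32_optimizer', False):
                # Bucketed with grad_sync_dtype=torch.float32 in the patch
                fp32_params.append(param)
            elif is_float8tensor(param):
                # Bucketed with param_sync_dtype=torch.uint8 in the patch
                fp8_params.append(param)
            else:
                # Handled with the optimizer's default dtypes
                normal_params.append(param)
        return fp32_params, fp8_params, normal_params

Each group then gets its own call to the base-class init_params_bucket, and trivial fragments are cross-registered via add_param_to_bucket (see the diff above) so the resulting buckets stay dependent on the same set of provided params during overlapped syncs.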