From 5900d4edc41d8c2224a30b0abade66704628f4f3 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 18 Jan 2024 12:01:08 -0800 Subject: [PATCH 01/10] Only reduce amaxes after fp8 cast for last distopt bucket Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 115 +++++++++++++++------------- 1 file changed, 61 insertions(+), 54 deletions(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 24ffcc17e6c2..d24d6840fecc 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -424,45 +424,13 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA super()._check_params_shard_dtypes(params_buckets) return - # FP8 scaling factors - amaxes = [] - scales = [] - scale_invs = [] - i = -1 - for param in self.parameters(): - if not _is_fp8_tensor(param): - continue - i += 1 - fp8_meta = param._fp8_meta["scaling_fwd"] - fp8_meta_index = param._fp8_meta_index - amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) - scales.append(fp8_meta.scale[fp8_meta_index].view(1)) - scale_invs.append(param._scale_inv.view(1)) - - # Update cached scale-inverses - packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) - packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)] - _multi_tensor_copy( - scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf, - ) - torch.reciprocal(packed_scales, out=packed_scales) - _multi_tensor_copy( - packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf, - ) - # Cast local data to FP8 fp8_params_shards = dict() - for param in self.parameters(): - if not _is_fp8_tensor(param): - continue - - # FP8 metadata - fp8_meta = param._fp8_meta["scaling_fwd"] - fp8_meta_index = param._fp8_meta_index - fp8_dtype = param._fp8_dtype - - # Iterate through fragments with local data - for fragment in self.state[param]["fragments"]: + for bucket_id, param_bucket in params_buckets.items(): + for fragment in self.state["buckets"][bucket_id].fragments: + param = self.parameter(fragment) + if not _is_fp8_tensor(param): + continue if not fragment.in_local_shard: continue shard_start, shard_end = fragment.shard_range @@ -470,12 +438,13 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA continue shard_range = slice(shard_start, shard_end) + # FP8 metadata + fp8_meta = param._fp8_meta["scaling_fwd"] + fp8_meta_index = param._fp8_meta_index + fp8_dtype = param._fp8_dtype + # Get bucket containing fragment - bucket_id = fragment.bucket_id - if bucket_id not in params_buckets: - continue state_bucket = self.state["buckets"][bucket_id] - param_bucket = params_buckets[bucket_id] if state_bucket.param_sync_dtype != torch.uint8: continue @@ -494,19 +463,57 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA for bucket_id, params_shard in fp8_params_shards.items(): params_buckets[bucket_id].params_shard = params_shard - # Reduce amaxes - # Note: Assume each param has a separate amax - packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) - packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] - _multi_tensor_copy( - amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, - ) - torch.distributed.all_reduce( - packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, - ) - _multi_tensor_copy( - packed_amax_views, amaxes, 
dummy_overflow_buf=self._dummy_overflow_buf, - ) + # Update FP8 scaling factors when all buckets have processed + if getattr(self, "_check_params_shard_dtypes_progress", None) is None: + self._check_params_shard_dtypes_progress = [] + self._check_params_shard_dtypes_progress.extend(params_buckets.keys()) + if len(self._check_params_shard_dtypes_progress) == len(self.state["buckets"]): + assert ( + len(set(self._check_params_shard_dtypes_progress)) == len(self.state["buckets"]) + ) + + # FP8 scaling factors + amaxes = [] + scales = [] + scale_invs = [] + i = -1 + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + i += 1 + fp8_meta = param._fp8_meta["scaling_fwd"] + fp8_meta_index = param._fp8_meta_index + amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) + scales.append(fp8_meta.scale[fp8_meta_index].view(1)) + scale_invs.append(param._scale_inv.view(1)) + + # Update cached scale-inverses + packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) + packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.reciprocal(packed_scales, out=packed_scales) + _multi_tensor_copy( + packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Reduce amaxes + # Note: Assume each param has a separate amax + packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.distributed.all_reduce( + packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, + ) + _multi_tensor_copy( + packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Reset + self._check_params_shard_dtypes_progress = None # Handle any remaining dtype conversions super()._check_params_shard_dtypes(params_buckets) From 11f709d8c1250a0eb346f7e8c70c097180e59d33 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 18 Jan 2024 18:57:55 -0800 Subject: [PATCH 02/10] Handle case with FP8 and contiguous param buffer Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 51 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index d24d6840fecc..0253eeb62b94 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -427,6 +427,29 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA # Cast local data to FP8 fp8_params_shards = dict() for bucket_id, param_bucket in params_buckets.items(): + state_bucket = self.state["buckets"][bucket_id] + if state_bucket.param_sync_dtype != torch.uint8: + continue + if not all( + _is_fp8_tensor(self.parameter(fragment)) + for fragment in state_bucket.fragments + ): + continue + + # Initialize FP8 buffer for param sync + params_shard = param_bucket.params_shard + if self.contiguous_param_buffer: + shard_size = state_bucket.shard_size + buffer_offset = state_bucket.contiguous_buffer_offset + buffer_start = buffer_offset + self.distributed_rank * shard_size + buffer_end = buffer_start + shard_size + param_buffer = self._param_buffers[state_bucket.dtypes()] + fp8_params_shard = param_buffer[buffer_start:buffer_end] + else: + fp8_params_shard = 
torch.empty_like(params_shard, dtype=torch.uint8) + param_bucket.params_shard = fp8_params_shard + + # Cast param fragments to FP8 for fragment in self.state["buckets"][bucket_id].fragments: param = self.parameter(fragment) if not _is_fp8_tensor(param): @@ -437,32 +460,14 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA if shard_end <= shard_start: continue shard_range = slice(shard_start, shard_end) - - # FP8 metadata - fp8_meta = param._fp8_meta["scaling_fwd"] - fp8_meta_index = param._fp8_meta_index - fp8_dtype = param._fp8_dtype - - # Get bucket containing fragment - state_bucket = self.state["buckets"][bucket_id] - if state_bucket.param_sync_dtype != torch.uint8: - continue - - # Allocate FP8 buffer if needed - if bucket_id not in fp8_params_shards: - fp8_params_shards[bucket_id] = torch.empty_like(param_bucket.params_shard, dtype=torch.uint8) - - # FP8 cast and amax - fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1) - fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1) cast_to_fp8( - fp32_fragment, fp8_meta, fp8_meta_index, fp8_dtype, out=fp8_fragment, + params_shard[shard_range].view(1, -1), + param._fp8_meta["scaling_fwd"], + param._fp8_meta_index, + param._fp8_dtype, + out=fp8_params_shard[shard_range].view(1, -1), ) - # Update param shards with FP8 buffers - for bucket_id, params_shard in fp8_params_shards.items(): - params_buckets[bucket_id].params_shard = params_shard - # Update FP8 scaling factors when all buckets have processed if getattr(self, "_check_params_shard_dtypes_progress", None) is None: self._check_params_shard_dtypes_progress = [] From 7e16a0982626d5c7afbd5ac2c8dff75e41129038 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 18 Jan 2024 22:02:24 -0800 Subject: [PATCH 03/10] Support distopt buckets with mixed dtypes Signed-off-by: Tim Moon --- .../language_modeling/megatron_base_model.py | 2 +- .../language_modeling/megatron_gpt_model.py | 6 +- nemo/core/optim/distributed_adam.py | 122 +++++++++++++----- 3 files changed, 94 insertions(+), 36 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 8e718e06c260..1ede5b85dc80 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -765,7 +765,7 @@ def configure_optimizers(self): if self.with_distributed_adam: # Initialize param buckets if explicitly provided - if getattr(self, 'distributed_adam_buckets', None): + if getattr(self, 'distributed_adam_buckets', None) is not None: for bucket in self.distributed_adam_buckets: self._optimizer.init_params_bucket(bucket) self._optimizer.init_params_bucket(self.parameters()) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 8790312f3e62..20a62092bc57 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -449,8 +449,8 @@ def configure_optimizers(self): param._disable_overlap_grad_sync = True # Initialize parameter buckets for overlapped grad and param syncs - # Note: Params with disabled overlapping are put in the - # last param bucket + # Note: Params with disabled overlapping and params in the + # first layer are put together in a bucket. 
buckets = [] if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: # Initialize a bucket for each virtual pipeline stage @@ -476,6 +476,8 @@ def configure_optimizers(self): [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] ) buckets.reverse() + used_params = set(itertools.chain.from_iterable(buckets)) + buckets[-1].extend(p for p in self.parameters() if p not in used_params) self.distributed_adam_buckets = buckets return super().configure_optimizers() diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 0253eeb62b94..f7c93ab3a9c8 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import itertools from typing import Callable, Dict, Iterable, Optional, Union import torch @@ -92,21 +93,6 @@ def __init__( # Construct distributed optimizer super().__init__(param_groups, **kwargs) - # Initialize weights that require FP32 grads - if self.dtype != torch.float32 or self.grad_sync_dtype != torch.float32: - fp32_params = [] - for param_group in param_groups: - fp32_params.extend( - filter(lambda param: getattr(param, '_with_fp32_optimizer', False), param_group['params']) - ) - if fp32_params: - assert ( - self.dtype == torch.float32 - ), f'Param requires FP32 state but optimizer is initialized with {self.dtype}' - self.init_params_bucket( - fp32_params, grad_sync_dtype=torch.float32, - ) - def _broadcast_params(self) -> None: # Assume params have already been synchronized pass @@ -167,16 +153,13 @@ def init_params( ) def init_params_bucket( - self, params: Iterable[torch.nn.Parameter], param_sync_dtype: Optional[torch.dtype] = None, **kwargs, + self, + params: Iterable[torch.nn.Parameter], + grad_sync_dtype: Optional[torch.dtype] = None, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, ) -> None: - """Initialize optimizer state for parameters in one effective bucket - - If any FP8 params are detected, all non-FP8 params are removed - from the bucket and their overlapped grad syncs are disabled. - This assumes that weight matrices are FP8 params and that - non-FP8 params are small (e.g. biases and layer norm params). 
- - """ + """Initialize optimizer state for parameters in one effective bucket""" # Ignore parameters that have already been initialized if isinstance(params, torch.Tensor): @@ -185,18 +168,91 @@ def init_params_bucket( if not params: return - # Ignore non-FP8 params if there are any FP8 params - if any(_is_fp8_tensor(param) for param in params): - for param in params: - if not _is_fp8_tensor(param): - param._disable_overlap_grad_sync = True - params = filter(_is_fp8_tensor, params) - param_sync_dtype = torch.uint8 + # Initialize parameters with FP32 grads + fp32_params = [] + remaining_params = [] + for param in params: + if getattr(param, '_with_fp32_optimizer', False): + fp32_params.append(param) + else: + remaining_params.append(param) + params = remaining_params + start_bucket_id = len(self.state["buckets"]) + super().init_params_bucket( + fp32_params, + grad_sync_dtype=torch.float32, + param_sync_dtype=param_sync_dtype, + **kwargs, + ) + end_bucket_id = len(self.state["buckets"]) + fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] + + # Initialize FP8 parameters + fp8_params = [] + remaining_params = [] + for param in params: + if _is_fp8_tensor(param): + fp8_params.append(param) + else: + remaining_params.append(param) + params = remaining_params + start_bucket_id = len(self.state["buckets"]) + super().init_params_bucket( + fp8_params, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=torch.uint8, + **kwargs, + ) + end_bucket_id = len(self.state["buckets"]) + fp8_buckets =self.state["buckets"][start_bucket_id:end_bucket_id] - # Initialize parameter buckets + # Initialize remaining parameters as usual + normal_buckets = [] + start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - params, param_sync_dtype=param_sync_dtype, **kwargs, + params, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=param_sync_dtype, + **kwargs, ) + end_bucket_id = len(self.state["buckets"]) + normal_buckets =self.state["buckets"][start_bucket_id:end_bucket_id] + + def add_param_to_bucket( + param: torch.nn.Parameter, + bucket: self.StateBucket, + ) -> None: + """Add trivial param fragment to bucket""" + param_fragments = self.state[param]["fragments"] + param_group_id = param_fragments[0].param_group_id + param_id = param_fragments[0].param_id + bucket_id = bucket.fragments[0].bucket_id + param_size = param.numel() + bucket_size = bucket.bucket_size + fragment = self.ParameterFragment( + param_group_id=param_group_id, + param_id=param_id, + bucket_id=bucket_id, + param_range=(param_size, param_size), + bucket_range=(bucket_size, bucket_size), + in_local_shard=False, + shard_range=None, + shard_bucket_range=None, + shard_param_range=None, + ) + param_fragments.append(fragment) + bucket.fragments.append(fragment) + + # Make sure all added buckets depend on provided params + for bucket in fp32_buckets: + for param in itertools.chain(fp8_params, params): + add_param_to_bucket(param, bucket) + for bucket in fp8_buckets: + for param in itertools.chain(fp32_params, params): + add_param_to_bucket(param, bucket) + for bucket in normal_buckets: + for param in itertools.chain(fp32_params, fp8_params): + add_param_to_bucket(param, bucket) def _init_param_state( self, From 3286773a29131b57bc38a8c4bb7a10947d344f7d Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 19 Jan 2024 00:53:23 -0800 Subject: [PATCH 04/10] Fix bug where fp8 casts were being skipped Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 5 ----- 1 file changed, 5 deletions(-) diff --git 
a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index f7c93ab3a9c8..83146c0e898d 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -486,11 +486,6 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA state_bucket = self.state["buckets"][bucket_id] if state_bucket.param_sync_dtype != torch.uint8: continue - if not all( - _is_fp8_tensor(self.parameter(fragment)) - for fragment in state_bucket.fragments - ): - continue # Initialize FP8 buffer for param sync params_shard = param_bucket.params_shard From 2d49eec743495d164dd44caa8d9a670796565c1c Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 22 Jan 2024 12:03:24 -0800 Subject: [PATCH 05/10] Debug FP8 params with contiguous param buffer Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 83146c0e898d..e08c7e26e6ea 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -306,10 +306,10 @@ def init_param_buffer(self) -> None: buffer_sizes = collections.defaultdict(lambda: 0) for bucket in self.state["buckets"]: dtypes = bucket.dtypes() - buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes],) + buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes]) for dtypes, buffer_size in buffer_sizes.items(): _, _, param_sync_dtype = dtypes - self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device,) + self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device) # Figure out corresponding positions in params and param buffer params = list(self.parameters()) @@ -350,7 +350,7 @@ def init_param_buffer(self) -> None: # Make all params a view into the param buffer for param, buffer_view in zip(params, param_buffer_views): if _is_fp8_tensor(param): - param._data.data = buffer_view.view(param.size()) + param._data = buffer_view.view(param.size()) else: param.data = buffer_view.view(param.size()) From 97d00f83d89ad0eabd5cb07c6fb9d4973738a510 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 22 Jan 2024 16:30:12 -0800 Subject: [PATCH 06/10] Separate non-FP8 params into leftover distopt bucket Signed-off-by: Tim Moon --- .../language_modeling/megatron_gpt_model.py | 24 ++++++----- nemo/core/optim/distributed_adam.py | 41 +++++++------------ nemo/utils/te_utils.py | 29 +++++++++++++ 3 files changed, 57 insertions(+), 37 deletions(-) create mode 100644 nemo/utils/te_utils.py diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 20a62092bc57..07d3c66191bc 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -64,6 +64,7 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.core.neural_types import ChannelType, NeuralType from nemo.utils import logging +from nemo.utils.te_utils import is_float8tensor try: import apex.transformer.pipeline_parallel.utils @@ -450,7 +451,17 @@ def configure_optimizers(self): # Initialize parameter buckets for overlapped grad and param syncs # Note: Params with disabled overlapping and params in the - # first layer are put together in a bucket. 
+ # first layer are put together in a bucket. If FP8 tensors + # are detected, those are also put in the first layer's + # bucket. + def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]: + bucket = [ + param for param in module.parameters() + if not getattr(param, '_disable_overlap_grad_sync', False) + ] + if any(is_float8tensor(param) for param in bucket): + bucket = list(filter(is_float8tensor, bucket)) + return bucket buckets = [] if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: # Initialize a bucket for each virtual pipeline stage @@ -459,11 +470,7 @@ def configure_optimizers(self): module = module.module stage_bucket = [] layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers - for layer in layers: - stage_bucket.extend( - p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False) - ) - buckets.append(stage_bucket) + buckets.extend(make_parameter_bucket(layer) for layer in layers) else: # Initialize a bucket for each Transformer layer modules = self.model if isinstance(self.model, list) else [self.model] @@ -471,10 +478,7 @@ def configure_optimizers(self): if isinstance(module, (Float16Module, MCoreFloat16Module)): module = module.module layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers - for layer in layers: - buckets.append( - [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] - ) + buckets.extend(make_parameter_bucket(layer) for layer in layers) buckets.reverse() used_params = set(itertools.chain.from_iterable(buckets)) buckets[-1].extend(p for p in self.parameters() if p not in used_params) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index e08c7e26e6ea..9bd74a426280 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -26,23 +26,10 @@ from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace from megatron.core.dist_checkpointing.mapping import ShardedTensor from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state +from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 from nemo.utils import str_to_dtype - -# Check if Transformer Engine has FP8 tensor class -HAVE_TE_FP8TENSOR = False -try: - from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 - from transformer_engine.pytorch.float8_tensor import Float8Tensor - - HAVE_TE_FP8TENSOR = True -except (ImportError, ModuleNotFoundError): - # Float8Tensor not found - pass - - -def _is_fp8_tensor(tensor: torch.Tensor) -> bool: - return HAVE_TE_FP8TENSOR and isinstance(tensor, Float8Tensor) +from nemo.utils.te_utils import is_float8tensor class MegatronDistributedFusedAdam(DistributedFusedAdam): @@ -144,9 +131,9 @@ def init_params( return # Initialize FP8 and non-FP8 tensors separately - if any(_is_fp8_tensor(param) for param in params): + if any(is_float8tensor(param) for param in params): super().init_params( - filter(_is_fp8_tensor, params), param_sync_dtype=torch.uint8, **kwargs, + filter(is_float8tensor, params), param_sync_dtype=torch.uint8, **kwargs, ) super().init_params( params, param_sync_dtype=param_sync_dtype, **kwargs, @@ -191,7 +178,7 @@ def init_params_bucket( fp8_params = [] remaining_params = [] for param in params: - if _is_fp8_tensor(param): + if is_float8tensor(param): fp8_params.append(param) else: remaining_params.append(param) @@ -272,7 +259,7 
@@ def _init_param_state( """ # Initialize non-FP8 params as usual - if not _is_fp8_tensor(param): + if not is_float8tensor(param): super()._init_param_state( param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs, ) @@ -331,7 +318,7 @@ def init_param_buffer(self) -> None: "Attempted to change a parameter with device={param.device} " f"into a buffer view with device={param_buffer_view.device}" ) - if _is_fp8_tensor(param): + if is_float8tensor(param): param_flat_views.append(param._data.detach().view(-1)) else: if param_buffer_view.dtype != param.dtype: @@ -349,8 +336,8 @@ def init_param_buffer(self) -> None: # Make all params a view into the param buffer for param, buffer_view in zip(params, param_buffer_views): - if _is_fp8_tensor(param): - param._data = buffer_view.view(param.size()) + if is_float8tensor(param): + param._data.data = buffer_view.view(param.size()) else: param.data = buffer_view.view(param.size()) @@ -428,7 +415,7 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet param_bucket = self._params_buckets[bucket_id] param = self.parameter(fragment) buffer_in = param_bucket.params_bucket[bucket_start:bucket_end] - if _is_fp8_tensor(param): + if is_float8tensor(param): # Copy into FP8 params's data buffer assert ( param_bucket.params_bucket.dtype == torch.uint8 @@ -462,7 +449,7 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet # Update transpose caches params = set(self.parameter(fragment) for fragment in fragments) for param in params: - if _is_fp8_tensor(param): + if is_float8tensor(param): param.transpose(update_cache=True) @torch.no_grad() @@ -475,7 +462,7 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA """ # Just call base class function if there are no FP8 tensors - num_fp8_params = sum(1 for param in self.parameters() if _is_fp8_tensor(param)) + num_fp8_params = sum(1 for param in self.parameters() if is_float8tensor(param)) if num_fp8_params == 0: super()._check_params_shard_dtypes(params_buckets) return @@ -503,7 +490,7 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA # Cast param fragments to FP8 for fragment in self.state["buckets"][bucket_id].fragments: param = self.parameter(fragment) - if not _is_fp8_tensor(param): + if not is_float8tensor(param): continue if not fragment.in_local_shard: continue @@ -534,7 +521,7 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA scale_invs = [] i = -1 for param in self.parameters(): - if not _is_fp8_tensor(param): + if not is_float8tensor(param): continue i += 1 fp8_meta = param._fp8_meta["scaling_fwd"] diff --git a/nemo/utils/te_utils.py b/nemo/utils/te_utils.py new file mode 100644 index 000000000000..bc66483fc46e --- /dev/null +++ b/nemo/utils/te_utils.py @@ -0,0 +1,29 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +# Check if Transformer Engine has Float8Tensor class +HAVE_TE_FLOAT8TENSOR = False +try: + from transformer_engine.pytorch.float8_tensor import Float8Tensor + + HAVE_TE_FLOAT8TENSOR = True +except (ImportError, ModuleNotFoundError): + # Float8Tensor not found + pass + +def is_float8tensor(tensor: torch.Tensor) -> bool: + """Check if a tensor is a Transformer Engine Float8Tensor""" + return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor) From 31202bbdd91ffc749b837efadde168c61629b041 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 22 Jan 2024 16:47:58 -0800 Subject: [PATCH 07/10] Debug FP8 params with contiguous param buffer Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 9bd74a426280..a01dcb41c667 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -337,7 +337,7 @@ def init_param_buffer(self) -> None: # Make all params a view into the param buffer for param, buffer_view in zip(params, param_buffer_views): if is_float8tensor(param): - param._data.data = buffer_view.view(param.size()) + param._data = buffer_view.view(param.size()) else: param.data = buffer_view.view(param.size()) From 1f35eb6302013c5b4e0fac1f282c847af02bcc61 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jan 2024 01:53:07 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../language_modeling/megatron_gpt_model.py | 4 +-- nemo/core/optim/distributed_adam.py | 28 +++++-------------- nemo/utils/te_utils.py | 1 + 3 files changed, 10 insertions(+), 23 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 525f45f8cd35..b32328580468 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -456,12 +456,12 @@ def configure_optimizers(self): # bucket. 
def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]: bucket = [ - param for param in module.parameters() - if not getattr(param, '_disable_overlap_grad_sync', False) + param for param in module.parameters() if not getattr(param, '_disable_overlap_grad_sync', False) ] if any(is_float8tensor(param) for param in bucket): bucket = list(filter(is_float8tensor, bucket)) return bucket + buckets = [] if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: # Initialize a bucket for each virtual pipeline stage diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index a01dcb41c667..82a71b3698e2 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -166,10 +166,7 @@ def init_params_bucket( params = remaining_params start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - fp32_params, - grad_sync_dtype=torch.float32, - param_sync_dtype=param_sync_dtype, - **kwargs, + fp32_params, grad_sync_dtype=torch.float32, param_sync_dtype=param_sync_dtype, **kwargs, ) end_bucket_id = len(self.state["buckets"]) fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] @@ -185,30 +182,21 @@ def init_params_bucket( params = remaining_params start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - fp8_params, - grad_sync_dtype=grad_sync_dtype, - param_sync_dtype=torch.uint8, - **kwargs, + fp8_params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=torch.uint8, **kwargs, ) end_bucket_id = len(self.state["buckets"]) - fp8_buckets =self.state["buckets"][start_bucket_id:end_bucket_id] + fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] # Initialize remaining parameters as usual normal_buckets = [] start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - params, - grad_sync_dtype=grad_sync_dtype, - param_sync_dtype=param_sync_dtype, - **kwargs, + params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, **kwargs, ) end_bucket_id = len(self.state["buckets"]) - normal_buckets =self.state["buckets"][start_bucket_id:end_bucket_id] + normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] - def add_param_to_bucket( - param: torch.nn.Parameter, - bucket: self.StateBucket, - ) -> None: + def add_param_to_bucket(param: torch.nn.Parameter, bucket: self.StateBucket,) -> None: """Add trivial param fragment to bucket""" param_fragments = self.state[param]["fragments"] param_group_id = param_fragments[0].param_group_id @@ -511,9 +499,7 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA self._check_params_shard_dtypes_progress = [] self._check_params_shard_dtypes_progress.extend(params_buckets.keys()) if len(self._check_params_shard_dtypes_progress) == len(self.state["buckets"]): - assert ( - len(set(self._check_params_shard_dtypes_progress)) == len(self.state["buckets"]) - ) + assert len(set(self._check_params_shard_dtypes_progress)) == len(self.state["buckets"]) # FP8 scaling factors amaxes = [] diff --git a/nemo/utils/te_utils.py b/nemo/utils/te_utils.py index bc66483fc46e..8f073e211681 100644 --- a/nemo/utils/te_utils.py +++ b/nemo/utils/te_utils.py @@ -24,6 +24,7 @@ # Float8Tensor not found pass + def is_float8tensor(tensor: torch.Tensor) -> bool: """Check if a tensor is a Transformer Engine Float8Tensor""" return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor) From 7405e47787f75d3ee7c11befb7d8318373fb182a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 2 
Feb 2024 09:19:33 -0800 Subject: [PATCH 09/10] Make sure to update FP8 transpose cache Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 82a71b3698e2..8363916b71cb 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -438,6 +438,7 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet params = set(self.parameter(fragment) for fragment in fragments) for param in params: if is_float8tensor(param): + param._reset_caches() param.transpose(update_cache=True) @torch.no_grad() From 5c980bb63a45048aa793521761872e360b707f88 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 8 Feb 2024 18:51:51 +0000 Subject: [PATCH 10/10] Update Apex commit Avoid unnecessary FP8 weight transposes. Signed-off-by: Tim Moon --- Dockerfile | 4 ++-- README.rst | 2 +- nemo/core/optim/distributed_adam.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 094c8fff408e..6a5c48bee4c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -69,10 +69,10 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ git checkout 27cbe46714a50c43ed290f1b1472db8d2780c55c && \ pip install . -# Apex bugfix for PyTorch 23.11 container: https://github.com/NVIDIA/apex/pull/1760 +# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 RUN git clone https://github.com/NVIDIA/apex.git && \ cd apex && \ - git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 && \ + git checkout b496d85fb88a801d8e680872a12822de310951fd && \ pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ # Transformer Engine 1.2.0 diff --git a/README.rst b/README.rst index 78396d80dc45..44e5df6b7488 100644 --- a/README.rst +++ b/README.rst @@ -326,7 +326,7 @@ To install Apex, run git clone https://github.com/NVIDIA/apex.git cd apex - git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 + git checkout b496d85fb88a801d8e680872a12822de310951fd pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies. diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 8363916b71cb..a2316dabb023 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -440,6 +440,7 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet if is_float8tensor(param): param._reset_caches() param.transpose(update_cache=True) + param._lazy_transpose_cache = True @torch.no_grad() def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
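
The core pattern behind PATCH 01/10 is to defer the FP8 amax reduction until every distributed-optimizer bucket has been processed and then issue a single packed MAX all-reduce, rather than reducing on every call. The following is a minimal sketch of that pattern in isolation, not the NeMo implementation: the DeferredAmaxReducer name and its method are hypothetical, and plain torch.cat and Tensor.copy_ calls stand in for Apex's fused _multi_tensor_copy helper.

    import torch
    import torch.distributed as dist

    class DeferredAmaxReducer:
        """Run one packed MAX all-reduce of amaxes after the last bucket is processed.

        Sketch only: ``amaxes`` are one-element tensors (e.g. views into FP8
        scaling metadata); the real code packs and unpacks them with Apex's
        fused _multi_tensor_copy instead of per-tensor copies.
        """

        def __init__(self, num_buckets: int, process_group=None):
            self.num_buckets = num_buckets
            self.process_group = process_group
            self._processed_buckets = set()

        def on_bucket_processed(self, bucket_id: int, amaxes: list) -> None:
            # Record this bucket and return early until every bucket has been seen.
            self._processed_buckets.add(bucket_id)
            if len(self._processed_buckets) < self.num_buckets:
                return
            self._processed_buckets.clear()
            if not amaxes:
                return
            # Pack the one-element amax tensors so a single collective covers them all.
            packed = torch.cat([amax.view(1) for amax in amaxes])
            dist.all_reduce(packed, op=dist.ReduceOp.MAX, group=self.process_group)
            # Copy the reduced values back into the per-parameter amax tensors.
            for amax, reduced in zip(amaxes, packed.unbind()):
                amax.copy_(reduced)

Packing replaces many single-element collectives with one, and tracking which buckets have been handled keeps the reduction to roughly one collective per optimizer step even though the dtype-check hook is invoked once per bucket.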