From b049f83ec79c4235f2c68007a7b05828997e314b Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 7 Sep 2023 16:11:30 -0700 Subject: [PATCH 1/6] fp8 poc usage with megatron-core --- .../language_modeling/megatron_gpt_model.py | 21 ++++++++++++++----- nemo/core/optim/optimizer_with_main_params.py | 2 +- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 1d0e08abe6dd..fb55730d0558 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -231,11 +231,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), ) else: - self.model = build_model( - model_provider_func=self.model_provider_func, - wrap_with_ddp=False, - virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), - ) + fp8_enabled = cfg.get('fp8', False) + fp8_recipe = None + if fp8_enabled and HAVE_TE: + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=0, interval=1, fp8_format=transformer_engine.common.recipe.Format.E4M3 + ) + with transformer_engine.pytorch.fp8_autocast( + enabled=fp8_enabled, fp8_recipe=fp8_recipe + ): + self.model = build_model( + model_provider_func=self.model_provider_func, + wrap_with_ddp=False, + virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + ) # if we're not using interleaved, then self.model is a module. if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None: @@ -295,8 +304,10 @@ def get_inference_config(self): def model_provider_func(self, pre_process, post_process): """Model depends on pipeline paralellism.""" if self.mcore_gpt: + from megatron.core.models.gpt.gpt_decoder_spec import get_gpt_decoder_spec model = MCoreGPTModel( config=self.transformer_config, + spec=get_gpt_decoder_spec(), vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), max_sequence_length=self.cfg.get('encoder_seq_length', 512), pre_process=pre_process, diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index 922412f1e8a6..149dff44620a 100644 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -162,7 +162,7 @@ class MainParamsOptimizerWrapper(torch.optim.Optimizer): Arguments: optimizer: base optimizer such as Adam or SGD. fp32_grad_accum: to enable the use of fp32 in gradient accumulation and allreduce. - contiguous_grad_bucket: to enable allocating the master gradients in the + contiguous_grad_bucket: to enable allocating the master gradients in the contiguous memory space to reduce memory fragmentation. async_grad_allreduce: enable asynchronous gradient allreduce that is executed along with the training step backprop. 
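PATCH 1/6 above builds the model inside Transformer Engine's fp8_autocast context with a DelayedScaling recipe as a proof of concept. For reference, the short sketch below shows the conventional use of the same public API on a standalone TE layer, where the autocast wraps the forward pass rather than model construction; the layer type, sizes, and tensor names are illustrative and are not taken from NeMo or this patch.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Same recipe settings as in the patch: E4M3 format, no margin,
# scaling factors updated every iteration.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

# Illustrative TE module; FP8 GEMMs need an FP8-capable GPU and
# dimensions divisible by 16.
layer = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(16, 1024, device="cuda", requires_grad=True)

# GEMMs executed inside fp8_autocast run in FP8; the backward pass
# also uses FP8 under this recipe.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
out.sum().backward()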
From 48cc1d1775e4ce4d4d9c315f33a08d971b71c90a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 8 Sep 2023 17:42:20 -0700 Subject: [PATCH 2/6] Add FP8 support to distopt Signed-off-by: Tim Moon --- .../language_modeling/megatron_base_model.py | 4 +- .../language_modeling/megatron_gpt_model.py | 6 - nemo/core/optim/distributed_adam.py | 307 +++++++++++++++++- 3 files changed, 308 insertions(+), 9 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index a7b1b9521e3c..ee24d849dcff 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -532,9 +532,11 @@ def configure_optimizers(self): if self.with_distributed_adam: # Initialize param buckets if explicitly provided - if hasattr(self, 'distributed_adam_buckets'): + if getattr(self, 'distributed_adam_buckets', None): for bucket in self.distributed_adam_buckets: self._optimizer.init_params_bucket(bucket) + self._optimizer.init_params_bucket(self.parameters()) + if hasattr(self, 'distributed_adam_buckets'): del self.distributed_adam_buckets # Make sure all params are initialized so main grads are diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index fb55730d0558..ff0c8dc33abe 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -469,12 +469,6 @@ def configure_optimizers(self): [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] ) buckets.reverse() - used_params = set() - for bucket in buckets: - used_params.update(bucket) - remaining_params = [p for p in self.parameters() if p not in used_params] - if remaining_params: - buckets.append(remaining_params) self.distributed_adam_buckets = buckets return super().configure_optimizers() diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 62bba769f652..a1e28c894062 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -14,10 +14,14 @@ import collections import itertools -from typing import Callable, Iterable, Optional, Union +from typing import Callable, Dict, Iterable, Optional, Union import torch -from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam, _disable_pre_forward_hook +from apex.contrib.optimizers.distributed_fused_adam import ( + DistributedFusedAdam, + _disable_pre_forward_hook, + _multi_tensor_copy, +) from megatron.core import parallel_state from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace from megatron.core.dist_checkpointing.mapping import ShardedTensor @@ -27,6 +31,14 @@ optim_state_to_sharding_state, ) +# Check if Transformer Engine has FP8 tensor class +HAVE_TE_FP8TENSOR = False +try: + from transformer_engine.pytorch import Float8Tensor + HAVE_TE_FP8TENSOR = True +except (ImportError, ModuleNotFoundError): + pass + def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: if isinstance(dtype, torch.dtype): @@ -60,6 +72,10 @@ def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: return dtype +def _is_fp8_tensor(tensor: torch.Tensor) -> bool: + return HAVE_TE_FP8TENSOR and isinstance(tensor, Float8Tensor) + + class MegatronDistributedFusedAdam(DistributedFusedAdam): """Wrapper 
class that supports NeMo-Megatron optimizations @@ -114,6 +130,10 @@ def __init__( fp32_params, grad_sync_dtype=torch.float32, ) + def _broadcast_params(self) -> None: + # Assume params have already been synchronized + pass + def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int,) -> Callable: def hook(*unused): if getattr(param, '_pre_forward_hook_is_enabled', False): @@ -137,6 +157,122 @@ def hook(*unused): return hook + def init_params( + self, + params: Optional[Iterable[torch.nn.Parameter]] = None, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for parameters + + Initializes FP8 and non-FP8 params separately. + + """ + + # Default cases + if params is None: + params = self.parameters() + elif isinstance(params, torch.Tensor): + params = [params] + + # Ignore parameters that have already been initialized + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Initialize FP8 and non-FP8 tensors separately + if any(_is_fp8_tensor(param) for param in params): + super().init_params( + filter(_is_fp8_tensor, params), + param_sync_dtype=torch.uint8, + **kwargs, + ) + super().init_params( + params, + param_sync_dtype=param_sync_dtype, + **kwargs, + ) + + def init_params_bucket( + self, + params: Iterable[torch.nn.Parameter], + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for parameters in one effective bucket + + If any FP8 params are detected, all non-FP8 params are removed + from the bucket and their overlapped grad syncs are disabled. + This assumes that weight matrices are FP8 params and that + non-FP8 params are small (e.g. biases and layer norm params). + + """ + + # Ignore parameters that have already been initialized + if isinstance(params, torch.Tensor): + params = [params] + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Ignore non-FP8 params if there are any FP8 params + if any(_is_fp8_tensor(param) for param in params): + for param in params: + if not _is_fp8_tensor(param): + param._disable_overlap_grad_sync = True + params = filter(_is_fp8_tensor, params) + param_sync_dtype = torch.uint8 + + # Initialize parameter buckets + super().init_params_bucket( + params, + param_sync_dtype=param_sync_dtype, + **kwargs, + ) + + def _init_param_state( + self, + param: torch.nn.Parameter, + param_group_id: int, + param_id: int, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for a parameter + + Initializing the master weights requires slicing a flattened + view of the param. FP8 tensors do not handle these operations + gracefully, so we hack around it by explicitly casting to + FP32. 
+ + """ + + # Initialize non-FP8 params as usual + if not _is_fp8_tensor(param): + super()._init_param_state( + param, + param_group_id, + param_id, + param_sync_dtype=param_sync_dtype, + **kwargs, + ) + + # Return immediately if already initialized + if "fragments" in self.state[param]: + return + + # Initialize with FP32 copy of param + fp32_param = param.float() + super()._init_param_state( + fp32_param, + param_group_id, + param_id, + param_sync_dtype=torch.uint8, + **kwargs, + ) + self.state[param].update(self.state[fp32_param]) + del self.state[fp32_param] + def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None: def is_grad_copy_enabled(param: torch.nn.Parameter) -> bool: return not getattr(param, '_disable_greedy_grad_copy', False) and not getattr( @@ -183,6 +319,173 @@ def grad_norm( # Use cached grad norm return super().grad_norm() + @torch.no_grad() + def _param_copy_fragments( + self, + fragments: Iterable[DistributedFusedAdam.ParameterFragment], + ) -> None: + """Update parameter fragments with values from parameter buckets + + For FP8 params, values are copied directly into the FP8 data + buffer. + + """ + + # Figure out corresponding positions in param buckets and params + buffers_in = [] + buffers_out = [] + for fragment in fragments: + + # Check if fragment needs to be updated + bucket_id = fragment.bucket_id + bucket_start, bucket_end = fragment.bucket_range + param_start, param_end = fragment.param_range + if param_end <= param_start or bucket_id not in self._params_buckets: + continue + + # Corresponding positions in bucket and param + state_bucket = self.state["buckets"][bucket_id] + param_bucket = self._params_buckets[bucket_id] + param = self.parameter(fragment) + buffer_in = param_bucket.params_bucket[bucket_start:bucket_end] + if _is_fp8_tensor(param): + # Copy into FP8 params's data buffer + assert ( + param_bucket.params_bucket.dtype == torch.uint8 + ), "Expected FP8 params to perform param sync in UINT8" + buffer_out = param._data.view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + elif ( + torch.is_floating_point(buffer_in) + and torch.is_floating_point(param) + ): + # Cast between floating-point dtypes + buffer_out = param.detach().view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + buffer_out = param.detach().view(-1)[param_start:param_end] + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Copy data from parameter buckets to parameters + _multi_tensor_copy( + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + + @torch.no_grad() + def _check_params_shard_dtypes( + self, + params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket], + ) -> None: + """Make sure local shards of parameters are in expected datatypes + + For FP8 params, FP32 values are cast into FP8 using per-param + scaling factors and per-param amaxes are computed and reduced. 
+ + """ + + # Peform FP8 casts if needed + num_fp8_params = sum( + 1 for param in self.parameters() if _is_fp8_tensor(param) + ) + if num_fp8_params > 0: + + # Packed buffer for amax reductions + amaxes = torch.zeros( + num_fp8_params, + dtype=torch.float32, + device=self.device, + ) + amax_pos = -1 + + # Loop through FP8 tensors + fp8_params_shards = dict() + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + amax_pos += 1 + + # Loop through fragments with local data + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) + + # Get bucket containing fragment + bucket_id = fragment.bucket_id + if bucket_id not in params_buckets: + continue + state_bucket = self.state["buckets"][bucket_id] + param_bucket = params_buckets[bucket_id] + if state_bucket.param_sync_dtype != torch.uint8: + continue + + # Allocate FP8 buffer if needed + if bucket_id not in fp8_params_shards: + fp8_params_shards[bucket_id] = torch.empty_like( + param_bucket.params_shard, + dtype=torch.uint8, + ) + + # FP8 cast and amax + ### TODO Multi-tensor cast-amax + fp32_fragment = param_bucket.params_shard[shard_range] + fp8_fragment = Float8Tensor.from_float32( + param_bucket.params_shard[shard_range], + param._scale, + param._flavor, + ) + fp8_params_shards[bucket_id][shard_range].copy_( + fp8_fragment._data, + ) + amaxes[amax_pos:amax_pos+1].copy_(fp32_fragment.amax()) + + # Update param shards with FP8 buffers + for bucket_id, params_shard in fp8_params_shards.items(): + params_buckets[bucket_id].params_shard = params_shard + + # Reduce amaxes + torch.distributed.all_reduce( + amaxes, + op=torch.distributed.ReduceOp.MAX, + group=self.distributed_process_group, + ) + + # Unpack amaxes + ### TODO Handle + # buffers_in = [] + # buffers_out = [] + # pos = -1 + # for param in self.parameters(): + # if not _is_fp8_tensor(param): + # continue + # pos += 1 + # buffers_in.append(amaxes[pos:pos+1]) + # buffers_out.append(param._amax) + # _multi_tensor_copy( + # buffers_in, + # buffers_out, + # dummy_overflow_buf=self._dummy_overflow_buf, + # ) + + # Handle any remaining dtype conversions + super()._check_params_shard_dtypes(params_buckets) + def sharded_state_dict(self, model_sharded_state_dict): optimizer_state_dict = self.state_dict() From 5b9db8c074783c42fdd6a407dde63afecd29e0fd Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 11 Sep 2023 10:57:31 -0700 Subject: [PATCH 3/6] Correctly accumulate amax when param is split across buckets Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index a1e28c894062..fa4d91e49aec 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -444,6 +444,7 @@ def _check_params_shard_dtypes( # FP8 cast and amax ### TODO Multi-tensor cast-amax + ### TODO Use updated scale fp32_fragment = param_bucket.params_shard[shard_range] fp8_fragment = Float8Tensor.from_float32( param_bucket.params_shard[shard_range], @@ -453,7 +454,8 @@ def _check_params_shard_dtypes( fp8_params_shards[bucket_id][shard_range].copy_( fp8_fragment._data, ) - amaxes[amax_pos:amax_pos+1].copy_(fp32_fragment.amax()) + amax = torch.maximum(amaxes[amax_pos:amax_pos+1], fp32_fragment.amax()) + amaxes[amax_pos:amax_pos+1].copy_(amax) # Update param 
shards with FP8 buffers for bucket_id, params_shard in fp8_params_shards.items(): From 5c76317e594146fdbf5b156d8c1448b7d3584d71 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 18 Sep 2023 15:21:54 -0700 Subject: [PATCH 4/6] Debug FP8 casts in distopt Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 166 +++++++++++++++------------- 1 file changed, 90 insertions(+), 76 deletions(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index fa4d91e49aec..ae21a00eddfb 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -35,6 +35,8 @@ HAVE_TE_FP8TENSOR = False try: from transformer_engine.pytorch import Float8Tensor + from transformer_engine.pytorch.fp8 import get_fp8_te_dtype + from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 HAVE_TE_FP8TENSOR = True except (ImportError, ModuleNotFoundError): pass @@ -384,6 +386,17 @@ def _param_copy_fragments( dummy_overflow_buf=self._dummy_overflow_buf, ) + # Precompute transposes + ### TODO Optimized transpose kernel + for fragment in fragments: + param = self.parameter(fragment) + if _is_fp8_tensor(param): + param._transpose = None + for fragment in fragments: + param = self.parameter(fragment) + if _is_fp8_tensor(param): + param.transpose() + @torch.no_grad() def _check_params_shard_dtypes( self, @@ -396,94 +409,95 @@ def _check_params_shard_dtypes( """ - # Peform FP8 casts if needed + # Just call base class function if there are no FP8 tensors num_fp8_params = sum( 1 for param in self.parameters() if _is_fp8_tensor(param) ) - if num_fp8_params > 0: + if num_fp8_params == 0: + super()._check_params_shard_dtypes(params_buckets) + return - # Packed buffer for amax reductions - amaxes = torch.zeros( - num_fp8_params, - dtype=torch.float32, - device=self.device, + # Iterate through FP8 tensors + fp8_params_shards = dict() + amaxes = [] + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + + # FP8 scaling factors + fp8_meta = param.fp8_meta_view["scaling_fwd"] + fp8_meta_index = param.gemm_index + fp8_dtype = get_fp8_te_dtype( + param.fp8_meta_view["recipe"], + fprop_tensor=True, ) - amax_pos = -1 + fp8_meta.scale_inv[fp8_meta_index] = 1 / fp8_meta.scale[fp8_meta_index] + param._scale_inv_cache = fp8_meta.scale_inv[fp8_meta_index] + amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) - # Loop through FP8 tensors - fp8_params_shards = dict() - for param in self.parameters(): - if not _is_fp8_tensor(param): + # Iterate through fragments with local data + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: continue - amax_pos += 1 - - # Loop through fragments with local data - for fragment in self.state[param]["fragments"]: - if not fragment.in_local_shard: - continue - shard_start, shard_end = fragment.shard_range - if shard_end <= shard_start: - continue - shard_range = slice(shard_start, shard_end) - - # Get bucket containing fragment - bucket_id = fragment.bucket_id - if bucket_id not in params_buckets: - continue - state_bucket = self.state["buckets"][bucket_id] - param_bucket = params_buckets[bucket_id] - if state_bucket.param_sync_dtype != torch.uint8: - continue - - # Allocate FP8 buffer if needed - if bucket_id not in fp8_params_shards: - fp8_params_shards[bucket_id] = torch.empty_like( - param_bucket.params_shard, - dtype=torch.uint8, - ) + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) - # 
FP8 cast and amax - ### TODO Multi-tensor cast-amax - ### TODO Use updated scale - fp32_fragment = param_bucket.params_shard[shard_range] - fp8_fragment = Float8Tensor.from_float32( - param_bucket.params_shard[shard_range], - param._scale, - param._flavor, - ) - fp8_params_shards[bucket_id][shard_range].copy_( - fp8_fragment._data, + # Get bucket containing fragment + bucket_id = fragment.bucket_id + if bucket_id not in params_buckets: + continue + state_bucket = self.state["buckets"][bucket_id] + param_bucket = params_buckets[bucket_id] + if state_bucket.param_sync_dtype != torch.uint8: + continue + + # Allocate FP8 buffer if needed + if bucket_id not in fp8_params_shards: + fp8_params_shards[bucket_id] = torch.empty_like( + param_bucket.params_shard, + dtype=torch.uint8, ) - amax = torch.maximum(amaxes[amax_pos:amax_pos+1], fp32_fragment.amax()) - amaxes[amax_pos:amax_pos+1].copy_(amax) - # Update param shards with FP8 buffers - for bucket_id, params_shard in fp8_params_shards.items(): - params_buckets[bucket_id].params_shard = params_shard + # FP8 cast and amax + ### TODO Multi-tensor cast-amax + fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1) + fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1) + cast_to_fp8( + fp32_fragment, + fp8_meta, + fp8_meta_index, + fp8_dtype, + out=fp8_fragment, + ) - # Reduce amaxes - torch.distributed.all_reduce( - amaxes, - op=torch.distributed.ReduceOp.MAX, - group=self.distributed_process_group, - ) + # Update param shards with FP8 buffers + for bucket_id, params_shard in fp8_params_shards.items(): + params_buckets[bucket_id].params_shard = params_shard - # Unpack amaxes - ### TODO Handle - # buffers_in = [] - # buffers_out = [] - # pos = -1 - # for param in self.parameters(): - # if not _is_fp8_tensor(param): - # continue - # pos += 1 - # buffers_in.append(amaxes[pos:pos+1]) - # buffers_out.append(param._amax) - # _multi_tensor_copy( - # buffers_in, - # buffers_out, - # dummy_overflow_buf=self._dummy_overflow_buf, - # ) + # Reduce amaxes + packed_amaxes = torch.zeros( + num_fp8_params, + dtype=torch.float32, + device=self.device, + ) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))] + _multi_tensor_copy( + amaxes, + packed_amax_views, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.distributed.all_reduce( + packed_amaxes, + op=torch.distributed.ReduceOp.MAX, + group=self.distributed_process_group, + ) + _multi_tensor_copy( + packed_amax_views, + amaxes, + dummy_overflow_buf=self._dummy_overflow_buf, + ) # Handle any remaining dtype conversions super()._check_params_shard_dtypes(params_buckets) From d85ab40023d7f8cf6a9eae9872970b6ced01a6f9 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 19 Sep 2023 21:16:04 -0700 Subject: [PATCH 5/6] Optimize distopt handling of FP8 scaling factors Signed-off-by: Tim Moon --- nemo/core/optim/distributed_adam.py | 43 ++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index ae21a00eddfb..c2f14eacd2b6 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -336,6 +336,7 @@ def _param_copy_fragments( # Figure out corresponding positions in param buckets and params buffers_in = [] buffers_out = [] + fragments = list(fragments) for fragment in fragments: # Check if fragment needs to be updated @@ -387,7 +388,6 @@ def _param_copy_fragments( ) # Precompute transposes - ### TODO Optimized transpose 
kernel for fragment in fragments: param = self.parameter(fragment) if _is_fp8_tensor(param): @@ -417,23 +417,47 @@ def _check_params_shard_dtypes( super()._check_params_shard_dtypes(params_buckets) return - # Iterate through FP8 tensors - fp8_params_shards = dict() + # FP8 scaling factors amaxes = [] + scales = [] + scale_invs = torch.empty( + num_fp8_params, + dtype=torch.float32, + device=self.device, + ) + i = -1 + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + i += 1 + fp8_meta = param.fp8_meta_view["scaling_fwd"] + fp8_meta_index = param.gemm_index + amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) + scales.append(fp8_meta.scale[fp8_meta_index].view(1)) + param._scale_inv_cache = scale_invs[i] + + # Update cached scale-inverses + scale_inv_views = [scale_invs[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + scales, + scale_inv_views, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.reciprocal(scale_invs, out=scale_invs) + + # Cast local data to FP8 + fp8_params_shards = dict() for param in self.parameters(): if not _is_fp8_tensor(param): continue - # FP8 scaling factors + # FP8 metadata fp8_meta = param.fp8_meta_view["scaling_fwd"] fp8_meta_index = param.gemm_index fp8_dtype = get_fp8_te_dtype( param.fp8_meta_view["recipe"], fprop_tensor=True, ) - fp8_meta.scale_inv[fp8_meta_index] = 1 / fp8_meta.scale[fp8_meta_index] - param._scale_inv_cache = fp8_meta.scale_inv[fp8_meta_index] - amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) # Iterate through fragments with local data for fragment in self.state[param]["fragments"]: @@ -461,7 +485,6 @@ def _check_params_shard_dtypes( ) # FP8 cast and amax - ### TODO Multi-tensor cast-amax fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1) fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1) cast_to_fp8( @@ -477,12 +500,12 @@ def _check_params_shard_dtypes( params_buckets[bucket_id].params_shard = params_shard # Reduce amaxes - packed_amaxes = torch.zeros( + packed_amaxes = torch.empty( num_fp8_params, dtype=torch.float32, device=self.device, ) - packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))] + packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] _multi_tensor_copy( amaxes, packed_amax_views, From 3d2272972035157baa68456b3442ac9b954b2b7d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Sep 2023 20:03:17 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../language_modeling/megatron_gpt_model.py | 5 +- nemo/core/optim/distributed_adam.py | 101 ++++-------------- 2 files changed, 25 insertions(+), 81 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ff0c8dc33abe..dbd206e6b2a2 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -237,9 +237,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): fp8_recipe = transformer_engine.common.recipe.DelayedScaling( margin=0, interval=1, fp8_format=transformer_engine.common.recipe.Format.E4M3 ) - with transformer_engine.pytorch.fp8_autocast( - enabled=fp8_enabled, fp8_recipe=fp8_recipe - ): + with transformer_engine.pytorch.fp8_autocast(enabled=fp8_enabled, 
fp8_recipe=fp8_recipe): self.model = build_model( model_provider_func=self.model_provider_func, wrap_with_ddp=False, @@ -305,6 +303,7 @@ def model_provider_func(self, pre_process, post_process): """Model depends on pipeline paralellism.""" if self.mcore_gpt: from megatron.core.models.gpt.gpt_decoder_spec import get_gpt_decoder_spec + model = MCoreGPTModel( config=self.transformer_config, spec=get_gpt_decoder_spec(), diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index c2f14eacd2b6..5c05010da7e3 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -35,8 +35,9 @@ HAVE_TE_FP8TENSOR = False try: from transformer_engine.pytorch import Float8Tensor - from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 + from transformer_engine.pytorch.fp8 import get_fp8_te_dtype + HAVE_TE_FP8TENSOR = True except (ImportError, ModuleNotFoundError): pass @@ -185,21 +186,14 @@ def init_params( # Initialize FP8 and non-FP8 tensors separately if any(_is_fp8_tensor(param) for param in params): super().init_params( - filter(_is_fp8_tensor, params), - param_sync_dtype=torch.uint8, - **kwargs, + filter(_is_fp8_tensor, params), param_sync_dtype=torch.uint8, **kwargs, ) super().init_params( - params, - param_sync_dtype=param_sync_dtype, - **kwargs, + params, param_sync_dtype=param_sync_dtype, **kwargs, ) def init_params_bucket( - self, - params: Iterable[torch.nn.Parameter], - param_sync_dtype: Optional[torch.dtype] = None, - **kwargs, + self, params: Iterable[torch.nn.Parameter], param_sync_dtype: Optional[torch.dtype] = None, **kwargs, ) -> None: """Initialize optimizer state for parameters in one effective bucket @@ -227,9 +221,7 @@ def init_params_bucket( # Initialize parameter buckets super().init_params_bucket( - params, - param_sync_dtype=param_sync_dtype, - **kwargs, + params, param_sync_dtype=param_sync_dtype, **kwargs, ) def _init_param_state( @@ -252,11 +244,7 @@ def _init_param_state( # Initialize non-FP8 params as usual if not _is_fp8_tensor(param): super()._init_param_state( - param, - param_group_id, - param_id, - param_sync_dtype=param_sync_dtype, - **kwargs, + param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs, ) # Return immediately if already initialized @@ -266,11 +254,7 @@ def _init_param_state( # Initialize with FP32 copy of param fp32_param = param.float() super()._init_param_state( - fp32_param, - param_group_id, - param_id, - param_sync_dtype=torch.uint8, - **kwargs, + fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs, ) self.state[param].update(self.state[fp32_param]) del self.state[fp32_param] @@ -322,10 +306,7 @@ def grad_norm( return super().grad_norm() @torch.no_grad() - def _param_copy_fragments( - self, - fragments: Iterable[DistributedFusedAdam.ParameterFragment], - ) -> None: + def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.ParameterFragment],) -> None: """Update parameter fragments with values from parameter buckets For FP8 params, values are copied directly into the FP8 data @@ -355,14 +336,11 @@ def _param_copy_fragments( # Copy into FP8 params's data buffer assert ( param_bucket.params_bucket.dtype == torch.uint8 - ), "Expected FP8 params to perform param sync in UINT8" + ), "Expected FP8 params to perform param sync in UINT8" buffer_out = param._data.view(-1)[param_start:param_end] buffers_in.append(buffer_in) buffers_out.append(buffer_out) - elif ( - 
torch.is_floating_point(buffer_in) - and torch.is_floating_point(param) - ): + elif torch.is_floating_point(buffer_in) and torch.is_floating_point(param): # Cast between floating-point dtypes buffer_out = param.detach().view(-1)[param_start:param_end] buffers_in.append(buffer_in) @@ -382,9 +360,7 @@ def _param_copy_fragments( # Copy data from parameter buckets to parameters _multi_tensor_copy( - buffers_in, - buffers_out, - dummy_overflow_buf=self._dummy_overflow_buf, + buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf, ) # Precompute transposes @@ -398,10 +374,7 @@ def _param_copy_fragments( param.transpose() @torch.no_grad() - def _check_params_shard_dtypes( - self, - params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket], - ) -> None: + def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket],) -> None: """Make sure local shards of parameters are in expected datatypes For FP8 params, FP32 values are cast into FP8 using per-param @@ -410,9 +383,7 @@ def _check_params_shard_dtypes( """ # Just call base class function if there are no FP8 tensors - num_fp8_params = sum( - 1 for param in self.parameters() if _is_fp8_tensor(param) - ) + num_fp8_params = sum(1 for param in self.parameters() if _is_fp8_tensor(param)) if num_fp8_params == 0: super()._check_params_shard_dtypes(params_buckets) return @@ -420,11 +391,7 @@ def _check_params_shard_dtypes( # FP8 scaling factors amaxes = [] scales = [] - scale_invs = torch.empty( - num_fp8_params, - dtype=torch.float32, - device=self.device, - ) + scale_invs = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device,) i = -1 for param in self.parameters(): if not _is_fp8_tensor(param): @@ -439,9 +406,7 @@ def _check_params_shard_dtypes( # Update cached scale-inverses scale_inv_views = [scale_invs[i].view(1) for i in range(num_fp8_params)] _multi_tensor_copy( - scales, - scale_inv_views, - dummy_overflow_buf=self._dummy_overflow_buf, + scales, scale_inv_views, dummy_overflow_buf=self._dummy_overflow_buf, ) torch.reciprocal(scale_invs, out=scale_invs) @@ -454,10 +419,7 @@ def _check_params_shard_dtypes( # FP8 metadata fp8_meta = param.fp8_meta_view["scaling_fwd"] fp8_meta_index = param.gemm_index - fp8_dtype = get_fp8_te_dtype( - param.fp8_meta_view["recipe"], - fprop_tensor=True, - ) + fp8_dtype = get_fp8_te_dtype(param.fp8_meta_view["recipe"], fprop_tensor=True,) # Iterate through fragments with local data for fragment in self.state[param]["fragments"]: @@ -479,20 +441,13 @@ def _check_params_shard_dtypes( # Allocate FP8 buffer if needed if bucket_id not in fp8_params_shards: - fp8_params_shards[bucket_id] = torch.empty_like( - param_bucket.params_shard, - dtype=torch.uint8, - ) + fp8_params_shards[bucket_id] = torch.empty_like(param_bucket.params_shard, dtype=torch.uint8,) # FP8 cast and amax fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1) fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1) cast_to_fp8( - fp32_fragment, - fp8_meta, - fp8_meta_index, - fp8_dtype, - out=fp8_fragment, + fp32_fragment, fp8_meta, fp8_meta_index, fp8_dtype, out=fp8_fragment, ) # Update param shards with FP8 buffers @@ -500,26 +455,16 @@ def _check_params_shard_dtypes( params_buckets[bucket_id].params_shard = params_shard # Reduce amaxes - packed_amaxes = torch.empty( - num_fp8_params, - dtype=torch.float32, - device=self.device, - ) + packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device,) packed_amax_views = 
[packed_amaxes[i].view(1) for i in range(num_fp8_params)] _multi_tensor_copy( - amaxes, - packed_amax_views, - dummy_overflow_buf=self._dummy_overflow_buf, + amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, ) torch.distributed.all_reduce( - packed_amaxes, - op=torch.distributed.ReduceOp.MAX, - group=self.distributed_process_group, + packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, ) _multi_tensor_copy( - packed_amax_views, - amaxes, - dummy_overflow_buf=self._dummy_overflow_buf, + packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf, ) # Handle any remaining dtype conversions