diff --git a/Dockerfile b/Dockerfile
index 2e6b617087bc..ef97b070b24e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -45,12 +45,14 @@ RUN apt-get update && \
 WORKDIR /workspace/
 
 WORKDIR /tmp/
-# TODO: Remove once this Apex commit (5/12/23) is included in PyTorch
-# container
+
+# Distributed Adam support for multiple dtypes
 RUN git clone https://github.com/NVIDIA/apex.git && \
     cd apex && \
-    git checkout 8b7a1ff183741dd8f9b87e7bafd04cfde99cea28 && \
-    pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
+    git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82 && \
+    pip3 install -v --no-build-isolation --config-settings --build-option="--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" .
+
+RUN pip3 install git+https://github.com/timmoon10/TransformerEngine.git@float8tensor_experiments
 
 # uninstall stuff from base container
 RUN pip3 uninstall -y sacrebleu torchtext
diff --git a/README.rst b/README.rst
index 6fbe9047d0c4..8dc820ed83a1 100644
--- a/README.rst
+++ b/README.rst
@@ -41,14 +41,14 @@ Introduction
 ------------
 
-NVIDIA NeMo is a conversational AI toolkit built for researchers working on automatic speech recognition (ASR), 
-text-to-speech synthesis (TTS), large language models (LLMs), and 
+NVIDIA NeMo is a conversational AI toolkit built for researchers working on automatic speech recognition (ASR),
+text-to-speech synthesis (TTS), large language models (LLMs), and
 natural language processing (NLP).
-The primary objective of NeMo is to help researchers from industry and academia to reuse prior work (code and pretrained models) 
+The primary objective of NeMo is to help researchers from industry and academia to reuse prior work (code and pretrained models)
 and make it easier to create new `conversational AI models `_.
-All NeMo models are trained with `Lightning `_ and 
-training is automatically scalable to 1000s of GPUs. 
+All NeMo models are trained with `Lightning `_ and
+training is automatically scalable to 1000s of GPUs.
 Additionally, NeMo Megatron LLM models can be trained up to 1 trillion parameters using tensor and pipeline model parallelism.
 NeMo models can be optimized for inference and deployed for production use-cases with `NVIDIA Riva `_.
@@ -57,14 +57,14 @@ State of the Art pretrained NeMo models are freely available on `HuggingFace Hub
 `NVIDIA NGC `_.
 These models can be used to transcribe audio, synthesize speech, or translate text in just a few lines of code.
 
-We have extensive `tutorials `_ that 
+We have extensive `tutorials `_ that
 can all be run on `Google Colab `_.
 
-For advanced users that want to train NeMo models from scratch or finetune existing NeMo models 
+For advanced users who want to train NeMo models from scratch or finetune existing NeMo models
 we have a full suite of `example scripts `_ that support multi-GPU/multi-node training.
 
 For scaling NeMo LLM training on Slurm clusters or public clouds, please see the `NVIDIA NeMo Megatron Launcher `_.
-The NM launcher has extensive recipes, scripts, utilities, and documentation for training NeMo LLMs and also has an `Autoconfigurator `_ 
+The NM launcher has extensive recipes, scripts, utilities, and documentation for training NeMo LLMs and also has an `Autoconfigurator `_
 which can be used to find the optimal model parallel configuration for training on a specific cluster.
 Also see our `introductory video `_ for a high level overview of NeMo.
@@ -245,8 +245,8 @@ To install Apex, run
 
     git clone https://github.com/NVIDIA/apex.git
     cd apex
-    git checkout 57057e2fcf1c084c0fcc818f55c0ff6ea1b24ae2
-    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
+    git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82
+    pip3 install -v --no-build-isolation --config-settings --build-option="--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" .
 
 It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies.
@@ -283,7 +283,7 @@ Transformer Engine requires PyTorch to be built with CUDA 11.8.
 
 Flash Attention
 ~~~~~~~~~~~~~~~~~~~~
-Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models or use with attention bias (introduced from position encoding, e.g. Alibi), please install `flash-attn `_.
+Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, or to use it with an attention bias (introduced by position encodings such as ALiBi), please install `flash-attn `_.
 
 .. code-block:: bash
@@ -292,7 +292,7 @@ Transformer Engine already supports Flash Attention for GPT models. If you want
 
 NLP inference UI
 ~~~~~~~~~~~~~~~~~~~~
-To launch the inference web UI server, please install the gradio `gradio `_.
+To launch the inference web UI server, please install `gradio `_.
 
 .. code-block:: bash
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 74a0201968a5..7b1c032e4e42 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -66,7 +66,7 @@ class MegatronBaseModel(NLPModel):
         - Initialize the model parallel world for nemo.
         - Turn on all of the nvidia optimizations.
-        - If `cfg.tokenizer` is available, it loads the tokenizer and pad the vocab to the 
+        - If `cfg.tokenizer` is available, it loads the tokenizer and pads the vocab to the
           correct size for tensor model parallelism.
         - If using distributed optimizer, configure to be compatible with
           O2 level optimizations and/or model parallelism.
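The vocab padding mentioned in the docstring above follows the usual Megatron rule: the tokenizer vocab is rounded up so the embedding table splits evenly across tensor-parallel ranks. A minimal sketch of that rule (the helper name and the divisor arguments are illustrative, not NeMo's API):

.. code-block:: python

    def padded_vocab_size(vocab_size: int, make_divisible_by: int, tensor_model_parallel_size: int) -> int:
        # Round up to a multiple of (make_divisible_by * TP size) so every
        # tensor-parallel rank gets an equal slice of the embedding table.
        multiple = make_divisible_by * tensor_model_parallel_size
        return ((vocab_size + multiple - 1) // multiple) * multiple

    # e.g. a 50257-token GPT-2 vocab with make_divisible_by=128 and TP=8 pads to 51200
    assert padded_vocab_size(50257, 128, 8) == 51200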
@@ -407,9 +407,8 @@ def setup_optimization( optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy() if self.with_distributed_adam: - # Allocate contiguous buffers to avoid extra copies + # Allocate contiguous buffer to avoid extra copies optim_kwargs['contiguous_grad_buffer'] = True - optim_kwargs['contiguous_param_buffer'] = True # Make sure optimizer state is in FP32 optim_dtype = torch.float32 @@ -490,9 +489,11 @@ def configure_optimizers(self): if self.with_distributed_adam: # Initialize param buckets if explicitly provided - if hasattr(self, 'distributed_adam_buckets'): + if getattr(self, 'distributed_adam_buckets', None): for bucket in self.distributed_adam_buckets: self._optimizer.init_params_bucket(bucket) + self._optimizer.init_params_bucket(self.parameters()) + if hasattr(self, 'distributed_adam_buckets'): del self.distributed_adam_buckets # Make sure all params are initialized so main grads are @@ -509,7 +510,8 @@ def configure_optimizers(self): self._optimizer.init_params(reversed(no_overlap_params)) # Initialize contiguous parameter buffer - self._optimizer.init_param_buffer() + if self._optimizer.contiguous_param_buffer: + self._optimizer.init_param_buffer() if self._scheduler is None: return self._optimizer diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 9c191a4ce078..b6058ee3f002 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -16,6 +16,7 @@ import os import queue import warnings +from contextlib import nullcontext from functools import partial from typing import Any, Dict, Iterator, List, Optional, Union @@ -226,11 +227,19 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), ) else: - self.model = build_model( - model_provider_func=self.model_provider_func, - wrap_with_ddp=False, - virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), - ) + fp8_enabled = cfg.get('fp8', False) and int(os.getenv("NEMO_WITH_FP8_PARAMS", "1")) + make_model_context = nullcontext + if fp8_enabled and HAVE_TE: + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=0, interval=1, fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + make_model_context = partial(transformer_engine.pytorch.fp8_model_init, enabled=True) + with make_model_context(): + self.model = build_model( + model_provider_func=self.model_provider_func, + wrap_with_ddp=False, + virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + ) # if we're not using interleaved, then self.model is a module. 
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None: @@ -437,10 +446,6 @@ def configure_optimizers(self): [p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)] ) buckets.reverse() - used_params = set() - for bucket in buckets: - used_params.update(bucket) - buckets[-1].extend(p for p in self.parameters() if p not in used_params) self.distributed_adam_buckets = buckets return super().configure_optimizers() diff --git a/nemo/collections/nlp/modules/common/megatron/clip_grads.py b/nemo/collections/nlp/modules/common/megatron/clip_grads.py index a1620931a695..4c38fdd1ef8c 100644 --- a/nemo/collections/nlp/modules/common/megatron/clip_grads.py +++ b/nemo/collections/nlp/modules/common/megatron/clip_grads.py @@ -200,7 +200,7 @@ def clip_grad_norm_distributed_optimizer(optimizer, max_norm, norm_type=2): # - parameter should not be shared # - should not be a replica due to tensor model parallelism params_for_norm = [] - for param in optimizer.parameters(with_fp32_optim_params=True): + for param in optimizer.parameters(): is_not_shared = param_is_not_shared(param) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) if is_not_shared and is_not_tp_duplicate: diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 98dba5423009..2ad44357dc4e 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -835,8 +835,6 @@ def __init__( params_dtype=params_dtype, get_rng_state_tracker=get_rng_state_tracker, fuse_wgrad_accumulation=fuse_wgrad_accumulation, - apply_query_key_layer_scaling=apply_query_key_layer_scaling, - attention_softmax_in_fp32=attention_softmax_in_fp32, seq_length=seq_length, micro_batch_size=micro_batch_size, sequence_parallel=sequence_parallel, diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 706bc48774e3..6f4714323e8e 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -14,31 +14,61 @@ import collections import itertools +from typing import Callable, Dict, Iterable, Optional, Union import torch from apex.contrib.optimizers.distributed_fused_adam import ( DistributedFusedAdam, - _coalescing_manager, - _coalescing_manager_append_work, _disable_pre_forward_hook, + _multi_tensor_copy, ) from megatron.core import parallel_state +# Check if Transformer Engine has FP8 tensor class +HAVE_TE_FP8TENSOR = False +try: + from transformer_engine.pytorch.cpp_extensions import cast_to_fp8 + from transformer_engine.pytorch.float8_tensor import Float8Tensor -def _str_to_dtype(dtype): + HAVE_TE_FP8TENSOR = True +except (ImportError, ModuleNotFoundError): + pass + + +def _str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: if isinstance(dtype, torch.dtype): return dtype name = str(dtype).strip().lower() - if name in ('', 'none'): - return torch.float32 - elif name in ('torch.float32', 'float32', 'float', 'fp32', '32'): - return torch.float32 - elif name in ('torch.float16', 'float16', 'half', 'fp16', '16'): - return torch.float16 - elif name in ('torch.bfloat16', 'bfloat16', 'bf16'): - return torch.bfloat16 - else: - raise ValueError(f'unsupported dtype ({dtype})') + if name.startswith("torch."): + name = name.replace("torch.", "", 1) + if name.startswith("fp"): + name = name.replace("fp", "float", 1) + dtype = dict( + float32=torch.float32, + float=torch.float32, + 
+        float64=torch.float64,
+        double=torch.float64,
+        float16=torch.float16,
+        half=torch.float16,
+        bfloat16=torch.bfloat16,
+        bf16=torch.bfloat16,
+        uint8=torch.uint8,
+        byte=torch.uint8,
+        int8=torch.int8,
+        char=torch.int8,
+        int16=torch.int16,
+        short=torch.int16,
+        int32=torch.int32,
+        int=torch.int32,
+        int64=torch.int64,
+        long=torch.int64,
+        bool=torch.bool,
+    )[name]
+    return dtype
+
+
+def _is_fp8_tensor(tensor: torch.Tensor) -> bool:
+    return HAVE_TE_FP8TENSOR and isinstance(tensor, Float8Tensor)
 
 
 class MegatronDistributedFusedAdam(DistributedFusedAdam):
@@ -49,7 +79,12 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
 
     """
 
-    def __init__(self, params, disable_distributed_parameters=False, **kwargs):
+    def __init__(
+        self,
+        params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
+        disable_distributed_parameters: bool = False,
+        **kwargs,
+    ):
 
         # Initialize process groups
         if 'process_group' not in kwargs and not parallel_state.is_unitialized():
@@ -72,78 +107,29 @@ def __init__(self, params, disable_distributed_parameters=False, **kwargs):
         if not isinstance(param_groups[0], dict):
             param_groups = [{'params': param_groups}]
 
-        # Check if explicit FP32 optimizer is needed
-        self._fp32_optim = None
-        distopt_param_groups = param_groups
-        dtype = kwargs['dtype'] if 'dtype' in kwargs else torch.float32
-        grad_sync_dtype = kwargs['grad_sync_dtype'] if 'grad_sync_dtype' in kwargs else dtype
-        needs_fp32_optimizer = dtype != torch.float32 or grad_sync_dtype != torch.float32
-        if needs_fp32_optimizer:
-            needs_fp32_optimizer = any(
-                any(getattr(param, '_with_fp32_optimizer', False) for param in param_group['params'])
-                for param_group in param_groups
-            )
-        if needs_fp32_optimizer:
+        # Construct distributed optimizer
+        super().__init__(param_groups, **kwargs)
 
-            # Find params that require explicit FP32 optimizer
-            distopt_param_groups = []
-            fp32_param_groups = []
-            self._fp32_optim_main_params = collections.OrderedDict()
+        # Initialize weights that require FP32 grads
+        if self.dtype != torch.float32 or self.grad_sync_dtype != torch.float32:
+            fp32_params = []
             for param_group in param_groups:
-                distopt_param_group = param_group.copy()
-                distopt_param_group['params'] = []
-                fp32_param_group = param_group.copy()
-                fp32_param_group['params'] = []
-                for model_param in param_group['params']:
-                    if getattr(model_param, '_with_fp32_optimizer', False):
-                        main_param = model_param.detach().clone().float()
-                        model_param.main_grad = main_param.grad
-                        fp32_param_group['params'].append(main_param)
-                        self._fp32_optim_main_params[model_param] = main_param
-                    else:
-                        distopt_param_group['params'].append(model_param)
-                distopt_param_groups.append(distopt_param_group)
-                fp32_param_groups.append(fp32_param_group)
-
-            # Add callback hook so grads accumulate into FP32 buffer
-            self._fp32_register_post_backward_hooks()
-
-            # Construct explicit FP32 optimizer
-            adamw_kwargs = {}
-            for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad'):
-                if name in kwargs:
-                    adamw_kwargs[name] = kwargs[name]
-            self._fp32_optim = torch.optim.AdamW(fp32_param_groups, **adamw_kwargs)
-            self._fp32_optim_grad_sync_needed = True
+                fp32_params.extend(
+                    filter(lambda param: getattr(param, '_with_fp32_optimizer', False), param_group['params'])
+                )
+            if fp32_params:
+                assert self.dtype == torch.float32, (
+                    'Param requires FP32 state, ' f'but optimizer is initialized with {self.dtype}'
+                )
+                self.init_params_bucket(
+                    fp32_params, grad_sync_dtype=torch.float32,
+                )
 
-        # Construct distributed optimizer
-
super().__init__(distopt_param_groups, **kwargs) - - def _fp32_register_post_backward_hooks(self): - """Attach hooks for FP32 gradients""" - - # Helper function to avoid issues with late binding closures - def make_post_backward_hook(param): - def post_backward_hook(*unused): - self._fp32_optim_grad_sync_needed = True - if hasattr(param, 'main_grad'): - with torch.no_grad(): - if param.grad is not None: - param.main_grad += param.grad - param.grad = None - - return post_backward_hook - - # Construct hooks and register with params - self._fp32_grad_accs = [] - for param in self._fp32_optim_main_params.keys(): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - hook = make_post_backward_hook(param) - grad_acc.register_hook(hook) - self._fp32_grad_accs.append(grad_acc) - - def _make_post_backward_hook(self, param, param_group_id, param_id): + def _broadcast_params(self) -> None: + # Assume params have already been synchronized + pass + + def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int) -> Callable: def hook(*unused): if getattr(param, '_pre_forward_hook_is_enabled', False): raise RuntimeError( @@ -166,68 +152,183 @@ def hook(*unused): return hook - def _filter_distopt_params(self, params): - if self._fp32_optim is None: - return params + def init_params( + self, + params: Optional[Iterable[torch.nn.Parameter]] = None, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for parameters + + Initializes FP8 and non-FP8 params separately. + + """ + + # Default cases if params is None: - return None - if isinstance(params, torch.Tensor): + params = self.parameters() + elif isinstance(params, torch.Tensor): params = [params] - return filter(lambda param: param not in self._fp32_optim_main_params, params) - def parameters(self, with_fp32_optim_params=False): - if with_fp32_optim_params and self._fp32_optim is not None: - return itertools.chain(super().parameters(), self._fp32_optim_main_params.keys()) - else: - return super().parameters() + # Ignore parameters that have already been initialized + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return - def init_params(self, params=None): - super().init_params(self._filter_distopt_params(params)) + # Initialize FP8 and non-FP8 tensors separately + if any(_is_fp8_tensor(param) for param in params): + super().init_params( + filter(_is_fp8_tensor, params), param_sync_dtype=torch.uint8, **kwargs, + ) + super().init_params( + params, param_sync_dtype=param_sync_dtype, **kwargs, + ) - def init_params_bucket(self, params): - super().init_params_bucket(self._filter_distopt_params(params)) + def init_params_bucket( + self, params: Iterable[torch.nn.Parameter], param_sync_dtype: Optional[torch.dtype] = None, **kwargs, + ) -> None: + """Initialize optimizer state for parameters in one effective bucket - def try_grad_sync(self, params): - params = self._filter_distopt_params(params) - params = [p for p in params if not getattr(p, '_disable_greedy_grad_copy', False)] - params = [p for p in params if not getattr(p, '_disable_overlap_grad_sync', False)] - for p in params: - self._grad_copy(p) - self._try_start_bucket_grad_sync(params=params) + If any FP8 params are detected, all non-FP8 params are removed + from the bucket and their overlapped grad syncs are disabled. + This assumes that weight matrices are FP8 params and that + non-FP8 params are small (e.g. 
biases and layer norm params). + + """ - def _try_start_bucket_param_sync(self, params=None): - super()._try_start_bucket_param_sync(self._filter_distopt_params(params)) + # Ignore parameters that have already been initialized + if isinstance(params, torch.Tensor): + params = [params] + params = [param for param in params if "fragments" not in self.state[param]] + if not params: + return + + # Ignore non-FP8 params if there are any FP8 params + if any(_is_fp8_tensor(param) for param in params): + for param in params: + if not _is_fp8_tensor(param): + param._disable_overlap_grad_sync = True + params = filter(_is_fp8_tensor, params) + param_sync_dtype = torch.uint8 + + # Initialize parameter buckets + super().init_params_bucket( + params, param_sync_dtype=param_sync_dtype, **kwargs, + ) + + def _init_param_state( + self, + param: torch.nn.Parameter, + param_group_id: int, + param_id: int, + param_sync_dtype: Optional[torch.dtype] = None, + **kwargs, + ) -> None: + """Initialize optimizer state for a parameter + + Initializing the master weights requires slicing a flattened + view of the param. FP8 tensors do not handle these operations + gracefully, so we hack around it by explicitly casting to + FP32. + + """ + + # Initialize non-FP8 params as usual + if not _is_fp8_tensor(param): + super()._init_param_state( + param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs, + ) - def _fp32_optim_grad_sync(self): - if self._fp32_optim is None or not self._fp32_optim_grad_sync_needed: + # Return immediately if already initialized + if "fragments" in self.state[param]: return - for model_param, main_param in self._fp32_optim_main_params.items(): - if model_param.grad is not None: - main_param.grad += model_param.grad.detach() - with _coalescing_manager(self.process_group, self.device, async_ops=True) as cm: - for main_param in self._fp32_optim_main_params.values(): - _coalescing_manager_append_work( - cm, - torch.distributed.all_reduce( - main_param.grad, op=torch.distributed.ReduceOp.AVG, group=self.process_group, async_op=True, - ), + + # Initialize with FP32 copy of param + fp32_param = param.float() + super()._init_param_state( + fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs, + ) + self.state[param].update(self.state[fp32_param]) + del self.state[fp32_param] + + @torch.no_grad() + def init_param_buffer(self) -> None: + """Allocate contiguous buffers for param buckets + + For FP8 params, the FP8 data buffer is made a view into a + contiguous buffer. 
+ + """ + + # Make sure all params are initialized + self.contiguous_param_buffer = True + self.init_params() + + # Construct param buffers + buffer_sizes = collections.defaultdict(lambda: 0) + for bucket in self.state["buckets"]: + dtypes = bucket.dtypes() + buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes],) + for dtypes, buffer_size in buffer_sizes.items(): + _, _, param_sync_dtype = dtypes + self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device,) + + # Figure out corresponding positions in params and param buffer + params = list(self.parameters()) + param_flat_views = [] + param_buffer_views = [] + for i, param in enumerate(params): + fragment = self.state[param]["fragments"][0] + bucket_id = fragment.bucket_id + bucket = self.state["buckets"][bucket_id] + param_size = param.numel() + bucket_start, _ = fragment.bucket_range + buffer_offset = bucket.contiguous_buffer_offset + buffer_start = buffer_offset + bucket_start + buffer_end = buffer_start + param_size + param_buffer = self._param_buffers[bucket.dtypes()] + param_buffer_view = param_buffer[buffer_start:buffer_end].detach() + if param_buffer_view.device != param.device: + raise RuntimeError( + "Attempted to change a parameter with device={param.device} " + f"into a buffer view with device={param_buffer_view.device}" ) - cm.wait() - self._fp32_optim_grad_sync_needed = False + if _is_fp8_tensor(param): + param_flat_views.append(param._data.detach().view(-1)) + else: + if param_buffer_view.dtype != param.dtype: + raise RuntimeError( + f"Attempted to change a parameter with dtype={param.dtype} " + f"into a buffer view with dtype={param_buffer_view.dtype}" + ) + param_flat_views.append(param.detach().view(-1)) + param_buffer_views.append(param_buffer_view) + + # Copy values into param buffer + _multi_tensor_copy( + param_flat_views, param_buffer_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Make all params a view into the param buffer + for param, buffer_view in zip(params, param_buffer_views): + if _is_fp8_tensor(param): + param._data = buffer_view.view(param.size()) + else: + param.data = buffer_view.view(param.size()) + + def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None: + def is_grad_copy_enabled(param: torch.nn.Parameter) -> bool: + return not getattr(param, '_disable_greedy_grad_copy', False) and not getattr( + param, '_disable_overlap_grad_sync', False + ) - def zero_grad(self, *args, **kwargs): - super().zero_grad(*args, **kwargs) + params = list(filter(is_grad_copy_enabled, params)) + for p in params: + self._grad_copy(p) + self._try_start_bucket_grad_sync(params=params) - # Reset grads for explicit FP32 optimizer - if self._fp32_optim is not None: - self._fp32_optim_grad_sync_needed = True - self._fp32_optim.zero_grad(set_to_none=False) - for model_param, main_param in self._fp32_optim_main_params.items(): - if main_param.grad is None: - main_param.grad = torch.zeros_like(main_param) - if model_param.grad is not None: - model_param.grad.zero_() - model_param.main_grad = main_param.grad + def zero_grad(self, *args, **kwargs) -> None: + super().zero_grad(*args, **kwargs) # Reset main grads if self.contiguous_grad_buffer: @@ -235,7 +336,9 @@ def zero_grad(self, *args, **kwargs): with _disable_pre_forward_hook(param): param.main_grad = self.grad_buffer_view(param) - def grad_norm(self, parameters=None, norm_type=2.0, force=False): + def grad_norm( + self, parameters: 
Optional[Iterable[torch.nn.Parameter]] = None, norm_type: float = 2.0, force: bool = False, + ) -> torch.Tensor: assert norm_type == 2 if parameters is not None: @@ -246,24 +349,10 @@ def grad_norm(self, parameters=None, norm_type=2.0, force=False): if force or self._grad_norm is None: # Compute norm of local gradients for distributed optimizer - grad_norm_sq = self._local_grad_norm( - parameters=self._filter_distopt_params(parameters), norm_type=norm_type, - ) + grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type) if self.redundant_size > 1: grad_norm_sq /= self.redundant_size - # Compute norm of local gradients for explicit FP32 optimizer - if self._fp32_optim is not None: - self._fp32_optim_grad_sync() - if parameters is None: - for main_param in self._fp32_optim_main_params.values(): - grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size - else: - for model_param in parameters: - if model_param in self._fp32_optim_main_params: - main_param = self._fp32_optim_main_params[model_param] - grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size - # Sum over all procs to get grad norm torch.distributed.all_reduce( grad_norm_sq, op=torch.distributed.ReduceOp.SUM, @@ -273,48 +362,162 @@ def grad_norm(self, parameters=None, norm_type=2.0, force=False): # Use cached grad norm return super().grad_norm() - def step(self, closure=None, *, grad_scaler=None): - - # Apply distributed optimizer - loss = super().step(closure=closure, grad_scaler=grad_scaler) - - if self._fp32_optim is not None: - - # Handle grad scaling - if grad_scaler is not None: - scaler_state = grad_scaler._per_optimizer_states[id(self)] - for _, found_inf in scaler_state['found_inf_per_device'].items(): - if found_inf.item(): - return loss - - # Update learning rate - for distopt_group, fp32_optim_group in zip(self.param_groups, self._fp32_optim.param_groups): - fp32_optim_group['lr'] = distopt_group['lr'] - - # Apply explicit FP32 optimizer - self._fp32_optim_grad_sync() - for main_param in self._fp32_optim_main_params.values(): - main_param.grad *= self._grad_scale - self._fp32_optim.step() - for model_param, main_param in self._fp32_optim_main_params.items(): - model_param.detach().copy_(main_param.detach()) - - return loss - - def state_dict(self, *args, **kwargs): - state_dict = super().state_dict(*args, **kwargs) - if self._fp32_optim is not None and state_dict is not None: - state_dict['fp32_optim'] = self._fp32_optim.state_dict() - state_dict['fp32_optim_fp32_params'] = list(self._fp32_optim_main_params.values()) - return state_dict - - def load_state_dict(self, state_dict): - if self._fp32_optim is not None and 'fp32_optim' in state_dict: - self._fp32_optim.load_state_dict(state_dict['fp32_optim']) - del state_dict['fp32_optim'] - for old_main_param, new_main_param in zip( - self._fp32_optim_main_params.values(), state_dict['fp32_optim_fp32_params'] - ): - old_main_param.copy_(new_main_param.detach()) - del state_dict['fp32_optim_fp32_params'] - return super().load_state_dict(state_dict) + @torch.no_grad() + def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.ParameterFragment]) -> None: + """Update parameter fragments with values from parameter buckets + + For FP8 params, values are copied directly into the FP8 data + buffer. 
+ + """ + + # Figure out corresponding positions in param buckets and params + buffers_in = [] + buffers_out = [] + fragments = list(fragments) + for fragment in fragments: + + # Check if fragment needs to be updated + bucket_id = fragment.bucket_id + bucket_start, bucket_end = fragment.bucket_range + param_start, param_end = fragment.param_range + if param_end <= param_start or bucket_id not in self._params_buckets: + continue + + # Corresponding positions in bucket and param + state_bucket = self.state["buckets"][bucket_id] + param_bucket = self._params_buckets[bucket_id] + param = self.parameter(fragment) + buffer_in = param_bucket.params_bucket[bucket_start:bucket_end] + if _is_fp8_tensor(param): + # Copy into FP8 params's data buffer + assert ( + param_bucket.params_bucket.dtype == torch.uint8 + ), "Expected FP8 params to perform param sync in UINT8" + buffer_out = param._data.view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + elif torch.is_floating_point(buffer_in) and torch.is_floating_point(param): + # Cast between floating-point dtypes + buffer_out = param.detach().view(-1)[param_start:param_end] + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + buffer_out = param.detach().view(-1)[param_start:param_end] + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Copy data from parameter buckets to parameters + _multi_tensor_copy( + buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + @torch.no_grad() + def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None: + """Make sure local shards of parameters are in expected datatypes + + For FP8 params, FP32 values are cast into FP8 using per-param + scaling factors and per-param amaxes are computed and reduced. 
+ + """ + + # Just call base class function if there are no FP8 tensors + num_fp8_params = sum(1 for param in self.parameters() if _is_fp8_tensor(param)) + if num_fp8_params == 0: + super()._check_params_shard_dtypes(params_buckets) + return + + # FP8 scaling factors + amaxes = [] + scales = [] + scale_invs = [] + i = -1 + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + i += 1 + fp8_meta = param._fp8_meta["scaling_fwd"] + fp8_meta_index = param._fp8_meta_index + amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1)) + scales.append(fp8_meta.scale[fp8_meta_index].view(1)) + scale_invs.append(param._scale_inv.view(1)) + + # Update cached scale-inverses + packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) + packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.reciprocal(packed_scales, out=packed_scales) + _multi_tensor_copy( + packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Cast local data to FP8 + fp8_params_shards = dict() + for param in self.parameters(): + if not _is_fp8_tensor(param): + continue + + # FP8 metadata + fp8_meta = param._fp8_meta["scaling_fwd"] + fp8_meta_index = param._fp8_meta_index + fp8_dtype = param._fp8_dtype + + # Iterate through fragments with local data + for fragment in self.state[param]["fragments"]: + if not fragment.in_local_shard: + continue + shard_start, shard_end = fragment.shard_range + if shard_end <= shard_start: + continue + shard_range = slice(shard_start, shard_end) + + # Get bucket containing fragment + bucket_id = fragment.bucket_id + if bucket_id not in params_buckets: + continue + state_bucket = self.state["buckets"][bucket_id] + param_bucket = params_buckets[bucket_id] + if state_bucket.param_sync_dtype != torch.uint8: + continue + + # Allocate FP8 buffer if needed + if bucket_id not in fp8_params_shards: + fp8_params_shards[bucket_id] = torch.empty_like(param_bucket.params_shard, dtype=torch.uint8) + + # FP8 cast and amax + fp32_fragment = param_bucket.params_shard[shard_range].view(1, -1) + fp8_fragment = fp8_params_shards[bucket_id][shard_range].view(1, -1) + cast_to_fp8( + fp32_fragment, fp8_meta, fp8_meta_index, fp8_dtype, out=fp8_fragment, + ) + + # Update param shards with FP8 buffers + for bucket_id, params_shard in fp8_params_shards.items(): + params_buckets[bucket_id].params_shard = params_shard + + # Reduce amaxes + # Note: Assume each param has a separate amax + packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] + _multi_tensor_copy( + amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, + ) + torch.distributed.all_reduce( + packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, + ) + _multi_tensor_copy( + packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf, + ) + + # Handle any remaining dtype conversions + super()._check_params_shard_dtypes(params_buckets)