diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index aea66b257f..f179569251 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -35,6 +35,8 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.fp8_autocast
 
+.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
+
 .. autoapifunction:: transformer_engine.pytorch.checkpoint
 
 .. autoapifunction:: transformer_engine.pytorch.onnx_export
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 268a534a82..54ba2a09c0 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
+pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py
new file mode 100644
index 0000000000..dc48c886cf
--- /dev/null
+++ b/tests/pytorch/test_float8tensor.py
@@ -0,0 +1,318 @@
+# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+from collections.abc import Iterable
+from typing import Any, Dict, List, Tuple, Union
+
+import pytest
+import torch
+
+import transformer_engine.common.recipe
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import transformer_engine_extensions as tex
+
+# PyTorch tensor dtypes
+_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
+# TE FP8 dtypes
+_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
+
+# Numerical tolerances with FP8 types
+_tols: Dict[tex.DType, Dict[str, float]] = {
+    tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675),  # epsilon = 0.0625
+    tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125),  # epsilon = 0.125
+}
+
+def _to_list(x: Union[Iterable, Any]) -> List:
+    """Convert to list if iterable, otherwise put in singleton list"""
+    if isinstance(x, Iterable):
+        return list(x)
+    else:
+        return [x]
+
+# Types that can be interpreted as tensor dims
+DimsType = Union[Iterable[int], int]
+
+# Check if FP8 is supported
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+class TestFloat8Tensor:
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    def test_constructor(
+        self,
+        dims: DimsType = 1,
+        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
+        scale_inv: float = 0.375,
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        """Call constructor and perform sanity checks"""
+        dims = _to_list(dims)
+        tensor = Float8Tensor(
+            data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
+            fp8_dtype=fp8_dtype,
+            fp8_scale_inv=torch.full([1], scale_inv),
+            dtype=dtype,
+        )
+        assert list(tensor.size()) == dims, "Incorrect dims"
+        assert tensor.dtype == dtype, "Incorrect nominal dtype"
+        assert tensor.is_cuda, "Incorrect device"
+
+    def _test_quantize_dequantize(
+        self,
+        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
+        scale: float = 3.5,
+        dtype: torch.dtype = torch.float32,
+        dims: DimsType = 23,
+    ) -> None:
+        """Check numerical error when casting to FP8 and back"""
+
+        #
Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_fp8 = x_fp8.from_float8().cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + def test_quantize_dequantize_dtypes( + self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + ) -> None: + self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype) + + @pytest.mark.parametrize("scale", [0.375, 1, 3.5]) + def test_quantize_dequantize_scales(self, scale: float) -> None: + self._test_quantize_dequantize(scale=scale) + + @pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]]) + def test_quantize_dequantize_dims(self, dims: DimsType) -> None: + self._test_quantize_dequantize(dims=dims) + + def test_fp8_meta( + self, + dtype: torch.dtype = torch.float32, + dims: DimsType = 23, + ) -> None: + """Construct Float8Tensor using FP8 metadata and perform basic checks""" + + # Get FP8 metadata from linear module + fp8_dtype = tex.DType.kFloat8E4M3 + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + module = te.Linear(32, 32) + _ = module(torch.zeros([8, 32], device="cuda")) + fp8_meta = module.fp8_meta + fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + + # Make Float8Tensor + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + x_ref = x_fp8.from_float8() + assert list(x_fp8.size()) == dims, "Incorrect dims" + assert x_fp8.dtype == dtype, "Incorrect nominal dtype" + assert x_fp8.is_cuda, "Incorrect device" + assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" + + # Change FP8 metadata scale + fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2 + fp8_meta[fp8_meta_key].scale_inv.fill_(123) + + # Check results + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + with pytest.raises(AssertionError): + # Make sure we are not trivially passing the test + torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) + + # Check if scaling factor is updated after in-place ops + x_fp8 += 0 + fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4 + fp8_meta[fp8_meta_key].scale_inv.fill_(321) + assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv" + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + y = x_fp8.detach() + y += 0 + assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv" + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + + def test_basic_ops( + self, + dims: DimsType = 23, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test basic out-of-place ops""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + 
fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + y_fp8 = Float8Tensor.to_float8( + y_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + y_ref = y_fp8.from_float8() + + # Exact operations + torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0) + torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0) + + # Operations with numerical error + tols = _tols[fp8_dtype] + torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols) + torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols) + torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols) + torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols) + torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols) + torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols) + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) + + def test_inplace_ops( + self, + dims: DimsType = 23, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test in-place ops""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + y_fp8 = Float8Tensor.to_float8( + y_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + y_ref = y_fp8.from_float8() + + # In-place operations + tols = _tols[fp8_dtype] + x_fp8 += y_ref + x_ref += y_ref + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + x_fp8 -= y_fp8 + x_ref -= y_fp8 + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + x_fp8 *= 2 + x_ref *= 2 + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + + # Make sure we are not trivially passing tests + x_ref += 123 + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, x_ref, **tols) + + @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) + @pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)]) + def test_transpose( + self, + dims: DimsType, + transpose_dims: Tuple[int, int], + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 1, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test transpose""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + + # Perform transpose + y_fp8 = x_fp8.transpose(*transpose_dims) + y_ref = x_ref.transpose(*transpose_dims) + + # Check results + tols = dict(rtol=0, atol=0) + torch.testing.assert_close(y_fp8, y_ref, **tols) + + # Make sure we are not trivially passing the test + if transpose_dims[0] != transpose_dims[1]: + with pytest.raises(AssertionError): + torch.testing.assert_close( + y_fp8, + x_ref, + **tols, + ) + + # Check transpose caching + if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, update_cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, 
update_cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, update_cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 02fb63e71f..474f0a95b9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,7 +12,7 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager +from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -339,7 +339,7 @@ def forward( return x -def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False): +def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): reset_rng_states() FP8GlobalStateManager.reset() @@ -354,24 +354,26 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, - params_dtype=dtype, + with fp8_model_init(enabled=fp8 and fp8_model_params): + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + fuse_qkv_params=True, + ) + .cuda() ) - .cuda() - ) te_inp_hidden_states = torch.randn( config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True @@ -400,18 +402,19 @@ def get_dummy_cuda_rng_tracker(): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("fp8", all_boolean) -def test_gpt_selective_activation_recompute(dtype, bs, model, fp8): +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) config = model_configs[model] - outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False) - outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True) + outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) + outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) assert_all_equal(outputs, outputs_recompute) -def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False): +def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): reset_rng_states() FP8GlobalStateManager.reset() @@ -426,7 +429,8 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER 
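For reference, the core round trip exercised by the new test file above boils down to the following minimal sketch (not part of the patch); it uses only the `to_float8`/`from_float8` API introduced in this change, the E4M3 tolerances from the tests, and assumes an FP8-capable GPU:

```python
import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Cast a tensor to FP8 (E4M3) with an explicit scale, then cast it back.
x = 2 * torch.rand(23, dtype=torch.float32, device="cuda") - 1
x_fp8 = Float8Tensor.to_float8(
    x,
    fp8_dtype=tex.DType.kFloat8E4M3,
    scale=torch.full([1], 3.5),
)
x_roundtrip = x_fp8.from_float8()  # plain tensor in the nominal dtype

# Quantization error stays within the E4M3 tolerances used in the tests.
torch.testing.assert_close(x_roundtrip, x, rtol=0.125, atol=0.0675)
```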
- block = ( + with fp8_model_init(enabled=fp8 and fp8_model_params): + block = ( TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -441,9 +445,10 @@ def get_dummy_cuda_rng_tracker(): output_layernorm=False, get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, + fuse_qkv_params=True, ) .cuda() - ) + ) te_inp_hidden_states = torch.randn( config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True @@ -483,14 +488,15 @@ def get_dummy_cuda_rng_tracker(): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("fp8", all_boolean) -def test_gpt_full_activation_recompute(dtype, bs, model, fp8): +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) config = model_configs[model] - outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False) - outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True) + outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) + outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) assert_all_equal(outputs, outputs_recompute) @@ -871,6 +877,7 @@ def test_linear_accuracy(dtype, bs, model): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -911,6 +918,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps): else: assert_allclose(te_outputs[0], torch_outputs[0], 2e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1110,3 +1118,72 @@ def test_gpt_cuda_graph(dtype, bs, model): assert_allclose(out, graphed_out, 1e-3) assert_allclose(params, graphed_params, 1e-3) assert_allclose(grads, graphed_grads, 1e-3) + + +def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): + reset_rng_states() + FP8GlobalStateManager.reset() + + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + + with fp8_model_init(enabled=fp8_model_params): + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + fuse_qkv_params=True, + ) + .cuda() + ) + + te_inp_hidden_states = torch.randn( + config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + te_inp_hidden_states.retain_grad() + te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + + with fp8_autocast(enabled=True): + te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) + loss = te_out.sum() + loss.backward() + torch.cuda.synchronize() + + outputs = [te_out, 
te_inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_gpt_fp8_parameters(dtype, bs, model): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs[model] + + outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) + outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) + assert_all_equal(outputs, outputs_fp8_params) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 4774cd39ab..dd50f15e43 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -147,7 +147,7 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): """Initialize the FP8 quantization scales in module""" NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. nb_total_scales = num_gemms * NB_SCALES_PER_GEMM - module.fp8_init(num_gemms) + module.init_fp8_metadata(num_gemms) module.fp8_meta["scaling_fwd"].scale = torch.ones( nb_total_scales, dtype=torch.float32, device="cuda") / scale module.fp8_meta["scaling_fwd"].scale_inv = torch.ones( diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index f35b60ede2..2732db6ad9 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -16,7 +16,7 @@ import torch import transformer_engine.pytorch as te import transformer_engine_extensions as tex -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8 +from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -93,7 +93,7 @@ def forward(self, inp, weight): model_in = Test_TE_Export(precision, True) with te.fp8_autocast(enabled=True): - model_in.fp8_init() + model_in.init_fp8_metadata() # scaling fwd model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 8ff601f6f1..b29853a3a7 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -13,6 +13,7 @@ from .attention import MultiheadAttention from .transformer import TransformerLayer from .fp8 import fp8_autocast +from .fp8 import fp8_model_init from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index abc3936e25..1d93d03f3f 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -83,14 +83,16 @@ def initialize_affine_weight_gpu( weight: torch.Tensor, init_method: Callable, get_rng_state_tracker: Callable, - partition_dim: int, + partition_dim: int = 0, stride: int = 1, + set_tp_attributes: bool = True, ) -> None: """Initialize affine weight for model parallel on GPU.""" - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) + if set_tp_attributes: + 
set_tensor_model_parallel_attributes( + tensor=weight, is_parallel=True, dim=partition_dim, stride=stride + ) if get_rng_state_tracker is None: init_method(weight) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py new file mode 100644 index 0000000000..1868bb4ed2 --- /dev/null +++ b/transformer_engine/pytorch/float8_tensor.py @@ -0,0 +1,689 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data""" +from __future__ import annotations +from typing import Any, Dict, Optional + +import torch +from torch.utils._pytree import tree_map +import transformer_engine_extensions as tex + +from .constants import TE_DType +from .fp8 import FP8GlobalStateManager + + +aten = torch.ops.aten +c10d = torch.ops.c10d + + +def _make_fp8_attr_property_funcs(name: str) -> Any: + """Make accessors for an FP8 attribute + + We store FP8 attributes in a dictionary so we can share them + between tensors with the same data, e.g. detached tensors. For + convenience, we also expose them as property attributes. This + function creates the accessors for property attributes. + + Parameters + ---------- + name: str + Key in dictionary of FP8 attributes + + """ + def get_func(self) -> Any: + return self._fp8_attrs[name] + def set_func(self, value: Any) -> None: + self._fp8_attrs[name] = value + def del_func(self) -> None: + del self._fp8_attrs[name] + return dict(fget=get_func, fset=set_func, fdel=del_func) + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + if dtype is None: + dtype = tensor.dtype + data = tensor._data.contiguous().view(1,-1).detach() + out = tex.cast_from_fp8( + data, + tensor._scale_inv, + tensor._fp8_dtype, + TE_DType[dtype], + ) + out = out.view(tensor.size()) + return out + + @staticmethod + def backward(ctx, grad): + # Assume that we want gradients in full precision + return grad, None + + +class _ToFloat8Func(torch.autograd.Function): + """Cast to FP8 from other dtype""" + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + ): + + # Manually compute scale-inverse if needed + if scale is not None and scale_inv is None: + if isinstance(scale, torch.Tensor): + scale_inv = scale.reciprocal() + else: + scale_inv = 1 / scale + + # Extract data from FP8 meta tensors if provided + if fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=fp8_meta_forward, + ) + if fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + if scale is None: + scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] + if amax is None: + amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + if scale_inv is None: + scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] + scale_inv = scale_inv.detach().view(1).clone() + + # Check input tensor + tensor = tensor.contiguous().cuda().detach() + if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): + tensor = 
tensor.float() + + # Check scale + if not isinstance(scale, torch.Tensor): + if scale is None: + scale = 1 + scale = torch.full( + [1], + scale, + dtype=torch.float32, + device=tensor.device, + ) + if scale.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale tensor" + ) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + # Check scale-inverse + if scale_inv is None: + scale_inv = scale.reciprocal() + scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) + + # Check amax + if amax is None: + amax = torch.empty_like(scale) + if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32): + raise ValueError( + "Attempted to initialize Float8Tensor with invalid amax tensor" + ) + + # Cast data to FP8 + data = tex.cast_to_fp8( + tensor.view(1,-1), + scale, + amax, + scale_inv, + fp8_dtype, + ) + data = data.view(tensor.size()) + + # Construct FP8 tensor + return Float8Tensor( + data=data, + fp8_meta=fp8_meta, + fp8_meta_forward=fp8_meta_forward, + fp8_meta_index=fp8_meta_index, + fp8_dtype=fp8_dtype, + fp8_scale_inv=scale_inv, + dtype=tensor.dtype, + ) + + @staticmethod + def backward(ctx, grad): + # Assume that we want gradients in full precision + return grad, None, None, None, None, None, None, None + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + + # Return input tensor if constructor kwargs are not provided + ctx.input_dtype = tensor.dtype + if init_kwargs is None: + return tensor + + # Construct new tensor if constructor kwargs are provided + default_kwargs = dict( + data=tensor._data, + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in init_kwargs: + init_kwargs[key] = val + return Float8Tensor(**init_kwargs) + + @staticmethod + def backward(ctx, grad): + return grad.to(ctx.input_dtype), None + + +class Float8Tensor(torch.Tensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_attrs: dict, optional + FP8 metadata, primarily managed by Float8Tensor. If + provided, all other FP8 configuration is ignored. + fp8_meta: dict, optional + FP8 metadata object, primarily managed by TE modules. + fp8_meta_forward: bool, default = `True` + Whether to access the FP8 metadata for the + forward pass. Ignored if fp8_meta is not + provided. + fp8_meta_index: int, optional + Index to access in FP8 meta tensors. Required if + fp8_meta is provided and otherwise ignored. + fp8_dtype: transformer_engine_extensions.DType, tex.DType.kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. 
+ dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. + + """ + + def __new__( + cls, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_scale_inv: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + ): + + # Check that data buffer is valid + if data.element_size() != 1: + raise ValueError( + "Float8Tensor requires data buffer with 8-bit dtype " + f"(got dtype={data.dtype})" + ) + if data.requires_grad: + raise ValueError( + "Float8Tensor requires non-differentiable data buffer" + ) + data = data.cuda() + + # Initialize tensor object + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data: torch.Tensor = data + + # Initialize dict of class attributes + # Note: We store FP8 attributes in a dictionary so we can + # share them between tensors with the same data, e.g. detached + # tensors. + self._fp8_attrs: dict = {} + if fp8_attrs is not None: + self._fp8_attrs = fp8_attrs + return self + + # FP8 meta tensors + if fp8_meta is not None and fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta + self._fp8_meta_forward: bool = fp8_meta_forward + self._fp8_meta_index: Optional[int] = fp8_meta_index + + # FP8 dtype + assert ( + fp8_dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2) + ), f"Unsupported fp8_dtype {fp8_dtype}." + self._fp8_dtype: tex.DType = fp8_dtype + + # Cached transpose + self._transpose: Optional[Float8Tensor] = None + + # FP8 scale-inverse + self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv + if self._scale_inv is None and self._fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + self._scale_inv = scale_inv.detach().view(1).clone() + if self._scale_inv is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying scale-inverse" + ) + if not isinstance(self._scale_inv, torch.Tensor): + self._scale_inv = torch.full( + [1], + self._scale_inv, + dtype=torch.float32, + device=self._data.device, + ) + if self._scale_inv.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" + ) + self._scale_inv = self._scale_inv.to( + device=self._data.device, + dtype=torch.float32, + ) + + return self + + @classmethod + def make_like( + cls, + tensor: Float8Tensor, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Float8Tensor: + """Use attributes of a Float8Tensor to create another Float8Tensor + + See constructor for list of keyword arguments. 
+ + """ + default_kwargs = dict( + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in kwargs: + kwargs[key] = val + return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) + + def __repr__(self): + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.from_float8(dtype=self.dtype)}" + ")" + ) + + def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. + """ + return _FromFloat8Func.apply(self, dtype) + + @classmethod + def to_float8( + cls, + tensor: torch.Tensor, + *, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + ): + """Construct Float8Tensor from plain PyTorch tensor""" + return _ToFloat8Func.apply( + tensor, + fp8_meta, + fp8_meta_forward, + fp8_meta_index, + fp8_dtype, + scale, + amax, + scale_inv, + ) + + def float(self) -> torch.Tensor: + return self.from_float8(dtype=torch.float32) + + def bfloat16(self) -> torch.Tensor: + return self.from_float8(dtype=torch.bfloat16) + + def half(self) -> torch.Tensor: + return self.from_float8(dtype=torch.float16) + + def cpu(self) -> torch.Tensor: + return self.from_float8().cpu() + + def clone(self) -> Float8Tensor: + return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) + + def expand_as(self, other: torch.Tensor): + if other is self: + # Note: expand_as is hackily used to create dummy autograd nodes + # and access the backward graph (see + # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). + # We equally hackily add a dummy function to handle this + # case. + return _IdentityFunc.apply(self) + return super().expand_as(other) + + def _transpose_no_cache(self) -> torch.Tensor: + """ + Swap tensor dimensions + + For basic 2D matrix transposes, an optimized transpose kernel + is applied and a Float8Tensor is returned. + """ + + # Use optimized kernel for basic 2D transpose + # TODO Support differentiation # pylint: disable=fixme + return Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous().detach(), + self._fp8_dtype, + ), + ) + + def transpose( + self, + dim0: int = 0, + dim1: int = 1, + *, + update_cache: Optional[bool] = None, + ) -> torch.Tensor: + """ + Swap tensor dimensions + + For basic 2D matrix transposes, an optimized transpose kernel + is applied and a Float8Tensor is returned. + + Parameters + ---------- + dim0: int, default = 0 + The first dimension to be transposed + dim1: int, default = 1 + The second dimension to be transposed + update_cache: Optional[bool], default = None + If set to `True`, the result is computed and stored in a cache. + If set to `False`, the result is computed only if the cache is + empty, otherwise the cache is returned. If set to `None`, the + result is not cached. Caching is only supported for basic 2D + transposes and the cache is reset after any in-place operations. 
+ """ + + # Handle non-2D transposes + if -self.dim() <= dim0 < 0: + dim0 += self.dim() + if -self.dim() <= dim1 < 0: + dim1 += self.dim() + if self.dim() != 2 or dim0 == dim1: + if update_cache is not None: + raise ValueError( + "Transpose caching is only supported for basic 2D transposes " + f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" + ) + return super().transpose(dim0, dim1) + + # No caching. + if update_cache is None: + return self._transpose_no_cache() + + # Update cache. + if update_cache or self._transpose is None: + self._transpose = self._transpose_no_cache() + + return self._transpose + + @torch.no_grad() + def reset_fp8_meta_scale_inv(self) -> None: + """Replace FP8 meta tensor scale-inverse with cached value + + The FP8 meta tensor scale_inv entry corresponding to this + tensor is replaced with the scale_inv value used to construct + the tensor. + + """ + if self._fp8_meta is None: + return + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + scale_inv.view(1).copy_(self._scale_inv.view(1)) + + def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: + """Create `Float8Tensor` with given nominal dtype + + The new tensor has the same underlying FP8 data. + + """ + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + dtype=dtype, + ) + + def _reset_caches(self) -> None: + """Reset cached values + + Should be called after any in-place operation. + + """ + self._transpose = None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # In-place copy op + if func == aten.copy_.default: + + # Check tensors + dst = args[0] + src = args[1] + if not isinstance(dst, Float8Tensor): + raise RuntimeError("Expected to copy into Float8Tensor") + if not isinstance(src, torch.Tensor): + raise RuntimeError("Expected to copy from tensor") + if not dst._data.is_contiguous(): + raise RuntimeError("Transformer Engine cast kernels require contiguous data") + + # Make sure input is in expected format + if isinstance(src, Float8Tensor): + src = src.from_float8() + src = src.expand(dst.size()) + src = src.to( + device=dst.device, + memory_format=torch.contiguous_format, + ) + + # Update scaling factor if FP8 meta tensors are available + if dst._fp8_meta is None: + scale = dst._scale_inv.reciprocal() + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index] + dst._scale_inv = scale.detach().view(1).reciprocal() + + # Cast to FP8 + tex.cast_to_fp8_noalloc( + src.view(1,-1), + scale, + dst._data.view(1,-1), + torch.empty_like(dst._scale_inv), # amax + dst._scale_inv, + dst._fp8_dtype, + ) + + # Nothing to return for in-place ops + dst._reset_caches() + return None + + # Slice op + # TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme + # if these slices are modified in-place + if func == aten.slice.Tensor: + tensor = args[0] + data = tensor._data + data_slice = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_slice) + + # Detach op + if func == aten.detach.default: + # Simply return a new Float8Tensor with the same attrs + return Float8Tensor.make_like( + args[0], + data=args[0]._data, + fp8_attrs=args[0]._fp8_attrs, + ) + + def maybe_unwrap(t): + if isinstance(t, Float8Tensor): + return 
t.from_float8() + return t + + def maybe_update_inplace(arg, new_arg, schema_arg): + """Update values of FP8 tensors + + Keep the same FP8 scaling factors. + + """ + if( + isinstance(arg, Float8Tensor) and + isinstance(new_arg, torch.Tensor) and + hasattr(schema_arg, 'alias_info') and + hasattr(schema_arg.alias_info, 'is_write') and + schema_arg.alias_info.is_write + ): + arg.copy_(new_arg) + arg._reset_caches() + + # In-place op + if func._schema.is_mutable: + # Cast to higher precision, perform op, and cast values + # back to original FP8 buffers + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + out = super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op + # Note: cast to higher precision and perform op + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + def _get_data(self) -> Float8Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Cast tensor to FP8 and store in FP8 buffer. + + """ + with torch.no_grad(): + self.copy_(tensor) + + # Cast to FP8 when setting Float8Tensor.data + data = property(_get_data, _set_data) + + # Accessors for objects in self._fp8_attrs + # Note: We store FP8 attributes in a dictionary so we can share + # them between tensors with the same data, e.g. detached tensors. + # For convenience, we also expose them as property attributes. 
+    _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta"))
+    _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward"))
+    _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
+    _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
+    _transpose = property(**_make_fp8_attr_property_funcs("transpose"))
+    _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
+
+    # Do not force the Float8Tensor type on the returned tensor
+    __torch_function__ = torch._C._disabled_torch_function_impl
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index c89ff10968..c7d4524113 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -17,7 +17,7 @@
 from .jit import jit_fuser
 
 
-__all__ = ["fp8_autocast"]
+__all__ = ["fp8_autocast", "fp8_model_init"]
 
 
 def check_fp8_support() -> Tuple[bool, str]:
@@ -59,6 +59,7 @@ class FP8GlobalStateManager:
     FP8_CALIBRATION = False
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
+    FP8_PARAMETERS = False
     IS_FIRST_FP8_MODULE = False
     FP8_AUTOCAST_COUNTER = 0
     FP8_CURRENT_CONTEXT_ID = 0
@@ -277,6 +278,11 @@ def is_fp8_calibration(cls) -> bool:
         """Is FP8 calibration"""
         return cls.FP8_CALIBRATION
 
+    @classmethod
+    def with_fp8_parameters(cls) -> bool:
+        """Should the parameters be stored as FP8"""
+        return cls.FP8_PARAMETERS
+
     @classmethod
     def is_first_fp8_module(cls):
         """Returns `True` only the first time when called multiple
@@ -400,6 +406,11 @@ def fp8_autocast_enter(
         fp8_group: Optional[dist_group_type] = None,
     ) -> None:
         """Set state and tracking variables for entry into FP8 region."""
+        if cls.FP8_AUTOCAST_DEPTH == 0:
+            if callable(cls.amax_forward_global_reduce_func):
+                cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func()  # pylint: disable=not-callable
+            cls.delete_key_from_amax_buffer(forward=True)
+
         cls.FP8_ENABLED = enabled
         cls.FP8_CALIBRATION = calibrating
         cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
@@ -419,11 +430,6 @@ def fp8_autocast_exit(cls):
         """Set state and tracking variables for exit from FP8 region."""
         cls.FP8_AUTOCAST_DEPTH -= 1
 
-        if cls.FP8_AUTOCAST_DEPTH == 0:
-            if callable(cls.amax_forward_global_reduce_func):
-                cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func()  # pylint: disable=not-callable
-            cls.delete_key_from_amax_buffer(forward=True)
-
     @classmethod
     def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
         """Copy the scaling factors and amaxes for recompute forward phase
@@ -477,9 +483,45 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
     fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
 
 
+@contextmanager
+def fp8_model_init(enabled: bool = True) -> None:
+    """
+    Context manager for FP8 initialization of parameters.
+
+    Example usage:
+
+    .. code-block:: python
+
+        with fp8_model_init(enabled=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+
+    Parameters
+    ----------
+    enabled: bool, default = `True`
+             when enabled, Transformer Engine modules created inside this `fp8_model_init`
+             region will hold only FP8 copies of its parameters, as opposed to the default
+             behavior where both higher precision and FP8 copies are present. Setting this
+             option to `True` may result in lower memory consumption and is especially
+             useful for scenarios like:
+
+             * full model training using optimizer with master weights, where the high
+               precision copies of weights are already present in the optimizer.
+             * inference, where only the FP8 copies of the parameters are used.
+             * LoRA-like fine-tuning, where the main parameters of the model do not change.
+
+             This functionality is *EXPERIMENTAL*.
+    """
+    try:
+        _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
+        FP8GlobalStateManager.FP8_PARAMETERS = enabled
+        yield
+    finally:
+        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters  # pylint: disable=used-before-assignment
+
+
 @contextmanager
 def fp8_autocast(
-    enabled: bool = False,
+    enabled: bool = True,
     calibrating: bool = False,
     fp8_recipe: Optional[DelayedScaling] = None,
     fp8_group: Optional[dist_group_type] = None,
@@ -508,7 +550,7 @@
 
     Parameters
     ----------
-    enabled: bool, default = `False`
+    enabled: bool, default = `True`
             whether or not to enable fp8
     calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
@@ -523,7 +565,10 @@
     """
     try:
         fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
-        FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group)
+        FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
+                                                 calibrating=calibrating,
+                                                 fp8_recipe=fp8_recipe,
+                                                 fp8_group=fp8_group)
         yield
     finally:
         FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)  # pylint: disable=used-before-assignment
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 5803cfa2f9..1dbc40dc70 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -36,6 +36,7 @@
     cast_to_fp8,
 )
 from ..constants import dist_group_type
+from ..float8_tensor import Float8Tensor
 
 _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
@@ -451,21 +452,29 @@ def set_fp8_weights(self) -> None:
             setattr(
                 self,
                 weight_cast_attr,
-                torch.empty(
-                    shape,
-                    device=torch.cuda.current_device(),
-                    dtype=torch.uint8,
-                ),
+                Float8Tensor(
+                    data=torch.empty(
+                        shape,
+                        device=torch.cuda.current_device(),
+                        dtype=torch.uint8,
+                    ),
+                    fp8_dtype=tex.DType.kFloat8E4M3,
+                    fp8_scale_inv=1,
+                )
             )
             setattr(
                 self,
                 weight_transpose_attr,
-                torch.empty(
-                    shape[1],
-                    shape[0],
-                    device=torch.cuda.current_device(),
-                    dtype=torch.uint8,
-                ),
+                Float8Tensor(
+                    data=torch.empty(
+                        shape[1],
+                        shape[0],
+                        device=torch.cuda.current_device(),
+                        dtype=torch.uint8,
+                    ),
+                    fp8_dtype=tex.DType.kFloat8E4M3,
+                    fp8_scale_inv=1,
+                )
             )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
@@ -483,12 +492,17 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
     # This routine is shared across FP8 and FP8_calibration paths so should not actually
     # assume FP8 execution.
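A minimal usage sketch of the new `fp8_model_init` API (not part of the patch), pieced together from the docstring above and the new asserts that require running inside `fp8_autocast` when weights are stored in FP8; the module choice, shapes, and dtype are illustrative and an FP8-capable GPU is assumed:

```python
import torch
import transformer_engine.pytorch as te

# Modules created here keep only FP8 copies of their parameters
# (stored as Float8Tensor) instead of FP8 plus higher-precision copies.
with te.fp8_model_init(enabled=True):
    model = te.Linear(768, 768)

# Because the parameters exist only in FP8, the forward pass must run
# inside an fp8_autocast region.
inp = torch.randn(16, 768, device="cuda")
with te.fp8_autocast(enabled=True):
    out = model(inp)
out.sum().backward()
```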
- def fp8_init(self, num_gemms: int = 1) -> None: + def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors() + if self.fp8 or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. if (self.fp8_initialized @@ -536,7 +550,7 @@ def prepare_forward( assert self.tp_group_initialized, "TP group not initialized." self.set_activation_dtype(inp) - self.fp8_init(num_gemms=num_gemms) + self.init_fp8_metadata(num_gemms=num_gemms) # Create persistent tensors for fp8 weights and their transposes # only when fp8 weight caching is used. @@ -765,7 +779,7 @@ def noop_cat(self, def get_fp8_weights_empty_tensors( self, is_first_microbatch: Union[bool, None], - ) -> List[torch.Tensor]: + ) -> List[Float8Tensor]: """ Returns empty tensors to be later used to store fp8 version of weights and their transposes (for the bwd pass) for this batch (or microbatch). @@ -781,23 +795,42 @@ def get_fp8_weights_empty_tensors( fp8_weight_tensors = [] for shape in self.fp8_weight_shapes: fp8_weight_tensors.append( - torch.empty( - shape, - device=torch.cuda.current_device(), - dtype=torch.uint8, + Float8Tensor( + data=torch.empty( + shape, + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=tex.DType.kFloat8E4M3, + fp8_scale_inv=1, ) ) - fp8_weight_tensors.append( - torch.empty( - shape[1], - shape[0], - device=torch.cuda.current_device(), - dtype=torch.uint8, + Float8Tensor( + data=torch.empty( + shape[1], + shape[0], + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=tex.DType.kFloat8E4M3, + fp8_scale_inv=1, ) ) return fp8_weight_tensors + def state_dict(self, *args, **kwargs) -> Dict: + """Get dictionary containing module state""" + state = super().state_dict(*args, **kwargs) + + # Convert Float8Tensors to plain tensors + # Note: Float8Tensors don't serialize well, especially if they + # contain references to FP8 metadata. 
+ for key, val in state.items(): + if isinstance(val, Float8Tensor): + state[key] = val.from_float8() + + return state @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a8e83631bc..d4746ba3a0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -23,7 +23,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -43,6 +43,7 @@ from ._common import _apply_normalization +from ..float8_tensor import Float8Tensor __all__ = ["LayerNormLinear"] @@ -79,10 +80,11 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + normalization: str, + primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_ag: bool, - normalization: str, ub_atomic_gemm_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -159,28 +161,43 @@ def forward( ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weight is already in FP8 + weight.reset_fp8_meta_scale_inv() + weight_fp8 = weight + weight_t_fp8 = None + if is_grad_enabled: + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + + elif update_fp8_weights: + # Need to cast weights to FP8 + weight_fp8 = Float8Tensor( + data=weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) if is_grad_enabled: tex.fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, + cast_out=weight_fp8._data, + transpose_out=weight_t_fp8._data, ) else: - weight_t_fp8 = None - weight_fp8 = tex.cast_to_fp8( + weight_fp8._data = tex.cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward) + fp8_dtype_forward, + ) + weight_t_fp8 = None ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo out, _ = tex.fp8_gemm( - weight_fp8, + weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -356,7 +373,7 @@ def backward( # DGRAD: Evaluated unconditionally to feed into Linear backward _ = tex.fp8_gemm( - weight_t_fp8, + weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -544,6 +561,7 @@ def backward( None, None, None, + None, ) @@ -646,10 +664,10 @@ def __init__( return_layernorm_output: bool = False, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -666,6 +684,7 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag @@ -719,18 +738,30 @@ def __init__( 
self.layer_norm_bias = None self.reset_layer_norm_parameters() - self.weight_tensor = torch.empty( + temp_weight = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.weight_tensor, + temp_weight, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + if self.primary_weights_in_fp8: + self.init_fp8_metadata() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.weight_tensor = Float8Tensor.to_float8( + temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + else: + self.weight_tensor = temp_weight + if self.use_bias: self.bias_tensor = torch.empty( self.out_features, @@ -769,10 +800,17 @@ def __init__( bname = pname + "bias" slice_end = slice_begin + slice_size - - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) - ) + # NOTE(future): Figure out a way to support slicing when weights + # are of `Float8Tensor` class + if self.primary_weights_in_fp8: + assert len(parameters_split) == 1, ("Slicing operation is not " + "supported in Float8Tensor " + "class!") + self.register_parameter(wname, Parameter(self.weight_tensor)) + else: + self.register_parameter( + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) + ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), @@ -833,7 +871,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: @@ -877,6 +915,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." 
bias_tensor = ( self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() @@ -927,10 +967,11 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.normalization, + self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_ag, - self.normalization, self.ub_atomic_gemm_ag, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d41c8d39df..40256dba6a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -47,6 +47,7 @@ from ..constants import dist_group_type, TE_DType from ..jit import no_torch_dynamo +from ..float8_tensor import Float8Tensor from ._common import _apply_normalization @@ -105,14 +106,15 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + activation: str, + normalization: str, + primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_rs: bool, ub_atomic_gemm_rs: bool, ub_split_ag: bool, ub_atomic_gemm_ag: bool, - activation: str, - normalization: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -196,45 +198,68 @@ def forward( fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weights are already in FP8 + fc1_weight.reset_fp8_meta_scale_inv() + fc2_weight.reset_fp8_meta_scale_inv() + fc1_weight_fp8 = fc1_weight + fc2_weight_fp8 = fc2_weight + fc1_weight_t_fp8 = None + fc2_weight_t_fp8 = None if is_grad_enabled: + fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch) + + elif update_fp8_weights: + # Need to cast weights to FP8 + fc1_weight_fp8 = Float8Tensor( + data=fc1_weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + fc2_weight_fp8 = Float8Tensor( + data=fc2_weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + if is_grad_enabled: + # Fused cast-transpose kernels tex.fp8_cast_transpose_fused( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=fc1_weight_fp8, - transpose_out=fc1_weight_t_fp8, + cast_out=fc1_weight_fp8._data, + transpose_out=fc1_weight_t_fp8._data, ) - tex.fp8_cast_transpose_fused( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, - cast_out=fc2_weight_fp8, - transpose_out=fc2_weight_t_fp8, + cast_out=fc2_weight_fp8._data, + transpose_out=fc2_weight_t_fp8._data, ) else: - fc1_weight_t_fp8 = None - fc1_weight_fp8 = tex.cast_to_fp8( + fc1_weight_fp8._data = tex.cast_to_fp8( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) - fc2_weight_t_fp8 = None - fc2_weight_fp8 = tex.cast_to_fp8( + fc1_weight_t_fp8 = None + fc2_weight_fp8._data = tex.cast_to_fp8( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, ) + fc2_weight_t_fp8 = None ub_algo = 
tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo fc1_out, _ = tex.fp8_gemm( - fc1_weight_fp8, + fc1_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -283,7 +308,7 @@ def forward( ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = tex.fp8_gemm( - fc2_weight_fp8, + fc2_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -530,7 +555,7 @@ def backward( ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_t_fp8, + fc2_weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -645,7 +670,7 @@ def backward( ) # FC1 DGRAD: Unconditional _ = tex.fp8_gemm( - fc1_weight_t_fp8, + fc1_weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -908,6 +933,7 @@ def backward( None, None, None, + None, ) @@ -1020,12 +1046,12 @@ def __init__( micro_batch_size: Optional[int] = None, set_parallel_mode: bool = False, zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_rs: bool = False, ub_atomic_gemm_rs: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -1043,6 +1069,7 @@ def __init__( self.activation == 'gelu') self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_rs = ub_split_rs @@ -1102,19 +1129,30 @@ def __init__( else: fc1_output_features = self.size_per_partition # FC1 init - self.fc1_weight = Parameter( - torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype) - ) - self.fp8_weight_shapes.append(self.fc1_weight.shape) + fc1_temp_weight = torch.empty( + fc1_output_features, hidden_size, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.fc1_weight, + fc1_temp_weight, init_method, get_rng_state_tracker, - partition_dim=0, - stride=1, + set_tp_attributes=False, ) + if self.primary_weights_in_fp8: + self.init_fp8_metadata(num_gemms=2) + self.fp8_meta["update_amax_and_scale_fwd"] = True + + fc1_temp_weight = Float8Tensor.to_float8( + fc1_temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + self.fc1_weight = Parameter(fc1_temp_weight) + set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) + self.fp8_weight_shapes.append(self.fc1_weight.shape) + if self.use_bias: self.fc1_bias = Parameter( torch.empty(fc1_output_features, device=device, dtype=params_dtype) @@ -1127,19 +1165,27 @@ def __init__( self.fc1_bias.zero_() # FC2 init - self.fc2_weight = Parameter( - torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype) - ) - self.fp8_weight_shapes.append(self.fc2_weight.shape) + fc2_temp_weight = torch.empty( + hidden_size, self.size_per_partition, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.fc2_weight, + fc2_temp_weight, output_layer_init_method, get_rng_state_tracker, - partition_dim=1, - stride=1, + 
set_tp_attributes=False, ) + if self.primary_weights_in_fp8: + fc2_temp_weight = Float8Tensor.to_float8( + fc2_temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + + self.fc2_weight = Parameter(fc2_temp_weight) + set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) + self.fp8_weight_shapes.append(self.fc2_weight.shape) + if self.use_bias: self.fc2_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) @@ -1192,7 +1238,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None, None, None] if is_first_microbatch is None: @@ -1235,6 +1281,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \ self.get_fp8_weights_scratchpad( @@ -1279,14 +1327,15 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.activation, + self.normalization, + self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_rs, self.ub_atomic_gemm_rs, self.ub_split_ag, self.ub_atomic_gemm_ag, - self.activation, - self.normalization, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5e2cab22fe..b14877e74b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -45,6 +45,8 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo +from ..float8_tensor import Float8Tensor + __all__ = ["Linear"] @@ -57,9 +59,9 @@ class _Linear(torch.autograd.Function): @staticmethod def forward( ctx, - weight: torch.Tensor, - weight_fp8: Union[torch.Tensor, None], - weight_t_fp8: Union[torch.Tensor, None], + weight: Union[Float8Tensor, torch.Tensor], + weight_fp8: Union[Float8Tensor, None], + weight_t_fp8: Union[Float8Tensor, None], inp: torch.Tensor, bias: torch.Tensor, use_bias: bool, @@ -75,6 +77,7 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, + primary_weights_in_fp8: bool, ub_split_rs: bool, ub_split_ag: bool, ub_atomic_gemm_rs: bool, @@ -141,24 +144,38 @@ def forward( ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weight is already in FP8 + weight.reset_fp8_meta_scale_inv() + weight_fp8 = weight + weight_t_fp8 = None + if is_grad_enabled: + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + + elif update_fp8_weights: + # Need to cast weights to FP8 + weight_fp8 = Float8Tensor( + data=weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) if is_grad_enabled: fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, + cast_out=weight_fp8._data, + transpose_out=weight_t_fp8._data, ) else: - 
weight_t_fp8 = None - weight_fp8 = cast_to_fp8( + weight_fp8._data = cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) + weight_t_fp8 = None proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( None, None, None, activation_dtype) @@ -184,7 +201,7 @@ def forward( ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = fp8_gemm( - weight_fp8, + weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -245,6 +262,9 @@ def forward( if is_grad_enabled: fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad + if fp8: + assert hasattr(weight_t_fp8, "_data"), \ + "_data attr doesn't exist (before save for bwd)" ctx.save_for_backward( inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None, inputmat_t if weight.requires_grad and fp8_wgrad else None, @@ -294,6 +314,9 @@ def backward( weight_t_fp8, fwd_scale_inverses, ) = ctx.saved_tensors + if weight_t_fp8 is not None: + assert hasattr(weight_t_fp8, "_data"), \ + "_data attr doesn't exist (after restore in bwd)" if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -349,7 +372,7 @@ def backward( if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( - weight_t_fp8, + weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -470,6 +493,7 @@ def backward( None, None, None, + None, ) @@ -554,9 +578,9 @@ def __init__( params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, + device: Union[torch.device, str] = "cuda", ub_split_rs: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_rs: bool = False, ub_atomic_gemm_ag: bool = False, ) -> None: @@ -570,6 +594,7 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.parameters_split = parameters_split + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag self.ub_atomic_gemm_rs = ub_atomic_gemm_rs @@ -609,18 +634,31 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - self.weight_tensor = torch.empty( + temp_weight = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) + # TODO(ksivaman): This functionality works with FP8 outside TE. 
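The `._data` accesses threaded through the GEMM calls above (and the temporary `hasattr(weight_t_fp8, "_data")` checks around save_for_backward) all rely on the same Float8Tensor layout: the raw FP8 bytes live in a uint8 `_data` buffer, a per-tensor `_scale_inv` records the dequantization factor, and `from_float8()` recovers a tensor in the nominal dtype. A small self-contained sketch of that layout, using the cast helpers as they appear in this patch (requires an FP8-capable GPU):

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor

x = torch.randn(128, 128, device="cuda", dtype=torch.float32)

# Cast to E4M3 with an explicit scale; the inverse scale is recorded at cast time.
x_fp8 = Float8Tensor.to_float8(
    x,
    fp8_dtype=tex.DType.kFloat8E4M3,
    scale=torch.full([1], 2.0),
)

print(x_fp8._data.dtype)       # torch.uint8 -- the buffer handed to tex.fp8_gemm(...)
print(x_fp8._scale_inv)        # 0.5, i.e. 1 / scale
print(x_fp8.dtype)             # torch.float32 -- the nominal dtype seen by autograd

# Dequantize and check the round-trip error stays within E4M3 resolution.
x_back = x_fp8.from_float8()
print((x_back - x).abs().max())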
initialize_affine_weight_gpu( - self.weight_tensor, + temp_weight, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + if self.primary_weights_in_fp8: + self.init_fp8_metadata() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.weight_tensor = Float8Tensor.to_float8( + temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + else: + self.weight_tensor = temp_weight + if self.use_bias: self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype) else: @@ -657,9 +695,17 @@ def __init__( slice_end = slice_begin + slice_size - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) - ) + # TODO(ksivaman): Add indexing op to torch dispatcher for float8 + if self.primary_weights_in_fp8: + assert len(parameters_split) == 1, ("Slicing operation is not " + "supported in Float8Tensor " + "class!") + self.register_parameter(wname, Parameter(self.weight_tensor)) + else: + + self.register_parameter( + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) + ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), @@ -697,13 +743,13 @@ def __init__( def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], - ) -> List[torch.Tensor]: + ) -> List[Float8Tensor]: """ Fetch the fp8 weight tensor placeholders if they exist (when `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: @@ -747,6 +793,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." bias_tensor = ( self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() @@ -790,6 +838,7 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), + self.primary_weights_in_fp8, self.ub_split_rs, self.ub_split_ag, self.ub_atomic_gemm_rs,
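With primary weights already stored in FP8, get_fp8_weights_scratchpad() now short-circuits to [None, None]: there is no per-microbatch cast to cache, because the parameter itself is the FP8 copy (plus a transpose cached on demand in the autograd function). A rough sketch of how the two storage modes differ from the caller's point of view, again assuming `te.fp8_model_init` is the construction-time switch and that the single weight parameter is exposed as `.weight`:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Conventional mode: the parameter stays in params_dtype (2 bytes/element for bf16);
# FP8 copies are produced under fp8_autocast and refreshed when is_first_microbatch=True.
dense = te.Linear(4096, 4096, params_dtype=torch.bfloat16)
print(dense.weight.dtype, dense.weight.element_size())      # torch.bfloat16, 2

# FP8-primary mode (assumed API): the parameter wraps a uint8 payload, so the weight
# costs 1 byte/element plus scaling metadata.
with te.fp8_model_init(enabled=True):
    dense_fp8 = te.Linear(4096, 4096)
print(isinstance(dense_fp8.weight, Float8Tensor))            # expected: True
print(dense_fp8.weight._data.dtype)                          # torch.uint8

# Nothing to stash per microbatch when the weight is already FP8:
print(dense_fp8.get_fp8_weights_scratchpad(is_first_microbatch=True))   # [None, None]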