diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eaededc4..1206012195 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail " python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" diff --git a/tests/pytorch/references/blockwise_fp8_gemm_reference.py b/tests/pytorch/references/blockwise_fp8_gemm_reference.py new file mode 100644 index 0000000000..5aef986e37 --- /dev/null +++ b/tests/pytorch/references/blockwise_fp8_gemm_reference.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128): + pid = tl.program_id(0) + idx = pid * BLOCK + tl.arange(0, BLOCK) + mask = idx < M * N + + row = idx // N + col = idx % N + + y_offset = row * y_str0 + col * y_str1 + x_offset = row * N + col + s_offset = row * N + col + + y = tl.load(y_ptr + y_offset, mask=mask) + x = tl.load(x_ptr + x_offset, mask=mask) + s = tl.load(s_ptr + s_offset, mask=mask) + + tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask) + + +def fused_fma(y, x, s, BLOCK=128): + """ + Fused multiply-add operation (y = y + x * s). + + PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation). + This function also supports cases where 'y' is non-contiguous in memory. + """ + + assert ( + y.shape == x.shape == s.shape and y.dim() == 2 + ), "All tensors must be 2D with the same shape" + assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous" + + M, N = y.shape + grid = ((M * N + BLOCK - 1) // BLOCK,) + + fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK) + + return y + + +class CuBLASRefBlockwiseGemm: + """ + A cuBLAS compatible reference implementation of subchannel GEMM. + """ + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + demunged_sx: torch.Tensor, + demunged_sw: torch.Tensor, + quant_tile_shape_x: Tuple[int, int], + quant_tile_shape_w: Tuple[int, int], + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + use_split_accumulator: bool = False, + ) -> torch.Tensor: + # demunge scale shapes for cuBLAS + is_a_1d_scaled = quant_tile_shape_x[0] == 1 + is_b_1d_scaled = quant_tile_shape_w[0] == 1 + M, K = qx.shape + N, K = qw.shape + + # mm_tile_shape = (tile_m, tile_n, tile_k) + mm_tile_shape = ( + quant_tile_shape_x[0], + quant_tile_shape_w[0], + quant_tile_shape_w[1], + ) + if bias is not None and bias.numel(): + # To match cuBLAS more closely when bias is applied, + # the reference accumulates into float32, and cast to + # bfloat16 is deferred until after the GEMM. + out_dtype_for_ref = torch.float32 + else: + out_dtype_for_ref = out_dtype + y = self.qgemm_blockwise_2d( + qx, + qw, + out_dtype_for_ref, + demunged_sx, + demunged_sw, + mm_tile_shape, + use_split_accumulator, + is_a_1d_scaled, + is_b_1d_scaled, + ) + if bias is not None and bias.numel(): + y += bias + y = y.to(dtype=out_dtype) + # cublas accumulation first convert to output dtype, then accumulate. + if accumulate: + assert out is not None + y = y + out + else: + assert out is None, "Output tensor should be None when accumulate is False." + + return y + + @classmethod + def qgemm_blockwise_2d( + cls, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + mm_tile_shape: Tuple[int, int, int], + use_split_accumulator: bool, + is_a_1d_scaled: bool, + is_b_1d_scaled: bool, + ) -> torch.Tensor: + """ + Difference between cuBLAS and CUTLASS GEMM implementations: + - cuBLAS accumulation equation: use different equation for each scaling mode. + - For accumulation C in epiloge, it first convert C to output dtype, then accumulate. + """ + + M, K = qx.shape + N, K_w = qw.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + tile_len = 128 + # Calculate grid sizes without padding + grid_m = (M + tile_len - 1) // tile_len + grid_n = (N + tile_len - 1) // tile_len + grid_k = (K + tile_len - 1) // tile_len + + block_m, block_n, block_k = mm_tile_shape + scale_m_per_tile = tile_len // block_m + scale_n_per_tile = tile_len // block_n + assert block_k == tile_len, "block_k must be equal to tile_len" + + # Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM: + # 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers. + # 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # Validate shapes of sx and sw + scale_m_per_tensor = (M + block_m - 1) // block_m + scale_n_per_tensor = (N + block_n - 1) // block_n + assert sx.shape == ( + scale_m_per_tensor, + grid_k, + ), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}" + assert sw.shape == ( + scale_n_per_tensor, + grid_k, + ), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}" + + for i in range(grid_m): + m_start = i * tile_len + m_end = min(m_start + tile_len, M) + m_size = m_end - m_start + + for j in range(grid_n): + n_start = j * tile_len + n_end = min(n_start + tile_len, N) + n_size = n_end - n_start + + y_block = y[m_start:m_end, n_start:n_end] + + for k in range(grid_k): + k_start = k * tile_len + k_end = min(k_start + tile_len, K) + k_size = k_end - k_start + + qx_block = ( + qx[m_start:m_end, k_start:k_end].clone().contiguous() + ) # Shape: [m_size, k_size] + qw_block = ( + qw[n_start:n_end, k_start:k_end].clone().contiguous() + ) # Shape: [n_size, k_size] + + # Extract scaling factors for the current blocks + sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze( + -1 + ) + sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0) + + # Perform qgemm with scaling factors fused in the GEMM + # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM + one = torch.tensor(1.0, dtype=torch.float32, device=qx.device) + y_partial = torch._scaled_mm( + qx_block, + qw_block.t(), + scale_a=one, + scale_b=one, + out_dtype=torch.float32, + use_fast_accum=not use_split_accumulator, + ) + + # Accumulate the partial result + if is_a_1d_scaled and is_b_1d_scaled: + # 1Dx1D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + # Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM + # y_block.add_(y_partial, alpha=scale.item()) + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + elif not is_a_1d_scaled and is_b_1d_scaled: + # 2Dx1D + # CuBLAS accumulation equation: y += (y * scale_b) * scale_a + y_partial = y_partial * sw_block + fused_fma( + y_block, + y_partial, + sx_block.expand_as(y_partial).contiguous(), + ) + elif is_a_1d_scaled and not is_b_1d_scaled: + # 1Dx2D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + else: + scale = sx_block * sw_block + fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous()) + + y = y.to(out_dtype) + return y diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index b98966f514..f5c9dc0e96 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -49,6 +49,7 @@ def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1) return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + @classmethod def demunge_scale_shape_from_backend( cls, qtensor_shape: Tuple[int, int], diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py new file mode 100644 index 0000000000..9a1cfa2db8 --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -0,0 +1,975 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from transformer_engine.pytorch.utils import get_device_compute_capability +from references.blockwise_quantizer_reference import CuBLASScaleMunger +from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm + + +def fp8_blockwise_gemm_supported() -> bool: + return ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ) + + +def cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + atol: float = 0.0, + rtol: float = 0.0 +): + if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2: + pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2") + if not (is_x_1d_scaled or is_w_1d_scaled): + pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile") + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + if noise_type == "uniform": + x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude + w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude + elif noise_type == "normal": + x = torch.randn(x_shape, dtype=torch.float32, device=device) * x_magnitude + w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude + else: + assert False + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude + else: + out = None + + assert not (use_bias and use_grad), "Bias grad not supported by GEMM" + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Reference GEMM + ref_gemm = CuBLASRefBlockwiseGemm() + scale_decoder = CuBLASScaleMunger() + qx_data = ( + qx._columnwise_data.view(dtype=x_dtype) + if x_columnwise + else qx._rowwise_data.view(dtype=x_dtype) + ) + qw_data = ( + qw._columnwise_data.view(dtype=w_dtype) + if w_columnwise + else qw._rowwise_data.view(dtype=w_dtype) + ) + ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv + ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv + y_ref = ref_gemm.qgemm( + qx=qx_data, + qw=qw_data, + out_dtype=out_dtype, + demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape + ), + demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape + ), + quant_tile_shape_x=x_quant_tile_shape, + quant_tile_shape_w=w_quant_tile_shape, + bias=bias, + out=out.clone() if accumulate else None, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM" + aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None + aux_tensor_ref = aux_tensor.clone() if use_gelu else None + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + aux_tensor, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y are not the same tensor + assert y_ref is not y, "y_ref and y should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y = torch.where(y.isnan(), torch.zeros_like(y), y) + + if use_gelu: + # Check + if use_grad: + # With use_grad, GEMM should use aux tensor to calculate + # gradient + gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None) + # TODO: How do we decide whether this is acceptably close? + # Could also try to put the activation inside the reference + # before the output cast to see different tolerances. + torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2) + else: + # aux tensor is pre-gelu aux output. Verify against y_ref. + torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol) + act = torch.nn.GELU() + gelu_ref = act(y_ref) + # gelu_ref = tex.gelu(y_ref, None) + torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol) + else: + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + + +def cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED", + expected_err_cls=RuntimeError +): + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=True, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + grad = use_grad + gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device) + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + with pytest.raises(expected_err_cls, match=expected_err_msg): + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_in, + grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (128, 128, 128), + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 128, 336), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (320, 256, 336), + (1024, 4096, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_shape_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + (320, 256, 336), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 64, 336), + # k > 128 + (256, 256, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_bias( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_bias=True, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["colxrow", "colxcol", "rowxcol"], +) +def test_cublas_gemm_fp8_blockwise_columnwise( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + is_x_columnwise, + is_w_columnwise, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (320, 64, 336), + # k > 128 + (256, 256, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "use_grad", + [ + True, + ], + ids=["grad"], +) +def test_cublas_gemm_fp8_gelu( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad, +): + # NOTE: cuBLAS doesn't complain with not use_grad, but the tests don't succeed + # so the epilogue is disabled on the transformer engine side. + if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled): + pytest.skip( + "CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)." + ) + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_gelu=True, + use_grad=use_grad, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_split_accumulator_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_bgrad_not_supported( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # NOTE: BGRAD epilogue is not supported for fp8. + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=True, + use_bias=True, + expected_err_msg="Epilogue requested outside of the available", + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"]) +@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_gelu_unsupported_cases_error( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_bias, + use_grad, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + if use_grad and not use_bias and out_dtype == torch.bfloat16: + pytest.skip("DGELU epilogue is supported for bfloat16.") + elif use_grad and not use_bias: + expected_err = "an unsupported value or parameter was passed" + else: + expected_err = "Epilogue requested outside of the available" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=use_grad, + use_bias=use_bias, + use_gelu=True, + expected_err_msg=expected_err, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_illegal_dtype_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # e5m2 by e5m2 not supported. + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (False, False), + ], + ids=["2Dx2D"], +) +def test_illegal_2D_by_2D_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # 2D block quantization by 2D block quantization is not supported. + expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg=expected_err_msg, + ) + + +@pytest.mark.parametrize( + "M, K, N, legalX1d, legalX2d", + [ + # M dim unconstrained when X is 2D. + (255, 128, 256, False, True), + # K must be multiple of 16 + (256, 120, 256, False, False), + # N must be a multiple of 8 + (256, 128, 252, False, False), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (False, True), + (True, True), + ], + ids=["1Dx2D", "2Dx1D", "1Dx1D"], +) +def test_unaligned_shapes( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + legalX1d, + legalX2d, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + legal = legalX1d if is_x_1d_scaled else legalX2d + if not legal: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg="dimension requirement", + ) + else: + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + "uniform", # noise type + 1.0, # x_magnitude + 1.0, # w_magnitude + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f19465c44b..6fe3539257 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -52,97 +52,173 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); } +/* Parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ struct GemmParam { - void *A; - void *B; - cublasOperation_t transA; - cublasOperation_t transB; - transformer_engine::DType Atype; - transformer_engine::DType Btype; - void *A_scale_inv; - void *B_scale_inv; - int lda; - int ldb; - - GemmParam(cublasOperation_t transA, cublasOperation_t transB) - : A(nullptr), - B(nullptr), - transA(transA), - transB(transB), - Atype(transformer_engine::DType::kNumTypes), - Btype(transformer_engine::DType::kNumTypes), - A_scale_inv(nullptr), - B_scale_inv(nullptr), - lda(0), - ldb(0) {} + void *A = nullptr; + void *B = nullptr; + cublasOperation_t transA = CUBLAS_OP_N; + cublasOperation_t transB = CUBLAS_OP_N; + transformer_engine::DType Atype = transformer_engine::DType::kNumTypes; + transformer_engine::DType Btype = transformer_engine::DType::kNumTypes; + void *A_scale_inv = nullptr; + void *B_scale_inv = nullptr; + int lda = 0; // A column strides + int ldb = 0; // B column strides }; +/* Populate parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - const int k, const int lda, const int ldb) { + int m, int n, int k) { using namespace transformer_engine; - // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. - // Must either force them both into a common block scaling mode or loosen this - // restriction. - NVTE_CHECK(A.scaling_mode == B.scaling_mode, - "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK( + A.scaling_mode == B.scaling_mode || + (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || + (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), + "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); - GemmParam ret(transA, transB); + GemmParam ret; + + // Device compute capability + const int arch = cuda::sm_arch(); - ret.lda = lda; - ret.ldb = ldb; + // Transpose mode with column-major ordering + bool transa_bool = transA == CUBLAS_OP_T; + bool transb_bool = transB == CUBLAS_OP_T; - // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases - // or need to be treated as `is_tensor_scaling`. + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.A = A.data.dptr; + ret.transA = transA; + ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - if (transA == CUBLAS_OP_T) { - ret.Atype = A.data.dtype; - } else { - ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; - if (is_fp8_dtype(ret.Atype)) { - int arch = cuda::sm_arch(cuda::current_device()); - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); - ret.A = A.columnwise_data.dptr; - ret.transA = CUBLAS_OP_T; - ret.A_scale_inv = A.columnwise_scale_inv.dptr; - ret.lda = k; - } + ret.lda = transa_bool ? k : m; + if (arch < 100 && !transa_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } + } else if (is_mxfp_scaling(A.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + } + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = transA; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = m; + } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); + } + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.lda % 16) == 0, + "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. + // Smallest supported CType is 2 bytes in this scaling mode. + NVTE_CHECK((m % 8) == 0, + "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } else { + NVTE_ERROR("A has unsupported scaling mode"); + } + + // Configure B matrix + if (is_tensor_scaling(B.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.B = B.data.dptr; + ret.transB = transB; + ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; - if (transB == CUBLAS_OP_T) { - ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; - if (is_fp8_dtype(ret.Btype)) { - int arch = cuda::sm_arch(cuda::current_device()); - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); - ret.B = B.columnwise_data.dptr; - ret.transB = CUBLAS_OP_N; - ret.B_scale_inv = B.columnwise_scale_inv.dptr; - ret.ldb = k; - } + ret.ldb = transb_bool ? n : k; + if (arch < 100 && transb_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } + } + } else if (is_mxfp_scaling(B.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = transB; + ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { - ret.Btype = B.data.dtype; + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + + // Requirements from + // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.ldb % 16) == 0, + "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { + // Observed this requirement only present for B tensor is 1D quantized. + NVTE_CHECK((n % 8) == 0, + "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); } } else { - // If not tensor scaling (which includes also high precision types), we need to - // use the proper version of data - // We leave the transA/B values as is, since Blackwell supports transposes - ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; - ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; - ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + NVTE_ERROR("B has unsupported scaling mode"); } + return ret; } @@ -153,18 +229,33 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, - int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, - void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, cudaStream_t stream) { + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + // Tensor dims in row-major order + const int A0 = inputA->flat_first_dim(); + const int A1 = inputA->flat_last_dim(); + const int B0 = inputB->flat_first_dim(); + const int B1 = inputB->flat_last_dim(); + + // GEMM dims in column-major order + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k, + "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, + ")"); + const int ldd = m; + // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { return; } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); + const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -226,6 +317,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, param.transA == CUBLAS_OP_N ? k : m, param.lda)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, param.transB == CUBLAS_OP_N ? n : k, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); @@ -249,12 +341,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized - // GEMM types. - // Scaling factors. #if CUDA_VERSION >= 12080 - cublasLtMatmulMatrixScale_t scaling_mode; + cublasLtMatmulMatrixScale_t scaling_mode_a; + cublasLtMatmulMatrixScale_t scaling_mode_b; #endif if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { void *A_scale_inverse = param.A_scale_inv; @@ -266,8 +356,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); #if CUDA_VERSION >= 12080 - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; - } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -276,7 +367,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. if (cublasLtGetVersion() <= 120803) { @@ -285,7 +377,32 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } -#endif + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { +#if CUDA_VERSION >= 12090 + float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D + ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D + ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; +#else + NVTE_ERROR("FP8 block scaling requires CUDA 12.9+"); +#endif // CUDA_VERSION >= 12090 +#endif // CUDA_VERSION >= 12080 } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + to_string(inputB->scaling_mode) + "."); @@ -293,9 +410,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUDA_VERSION >= 12080 NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); #endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output @@ -305,8 +422,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUDA_VERSION >= 12080 - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + // NOTE: In all current cases where FP8 output is supported, the input is + // scaled identically to the output. + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_D_SCALE_MODE, + &scaling_mode_a, sizeof(scaling_mode_a))); #endif // For FP8 output, cuBLAS requires C_type to match bias_type and // be FP16/BF16 @@ -364,6 +484,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } + if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) || + (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) { + NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS || + epilogue == CUBLASLT_EPILOGUE_DGELU), + "Epilogue requested outside of the available and tested cuBLAS functionality for " + "float8 block scaled GEMM"); + } + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); @@ -411,7 +539,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C @@ -469,35 +596,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const size_t A0 = inputA->flat_first_dim(); - const size_t A1 = inputA->flat_last_dim(); - const size_t B0 = inputB->flat_first_dim(); - const size_t B1 = inputB->flat_last_dim(); - - const int m = transa ? A0 : A1; - const int k = transa ? A1 : A0; - const int n = transb ? B1 : B0; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, 0, 0, false, nullptr, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -525,31 +626,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, + inputCounter, stream); } void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index dae39d82bf..f6b6ae22c2 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -27,7 +27,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -57,7 +57,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (cudnn_backend) { // TODO: add check for GPU ARCH diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 8519fe1b64..c56f9ef407 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -23,7 +23,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -47,7 +47,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;