From fbcbcb0924a16c8301e3e02ca0a5a1a266606179 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 28 Feb 2025 15:30:07 -0800 Subject: [PATCH 01/21] Add GEMM logic for blockwise quantized tensors. GEMM test cases included in pytorch integration. Signed-off-by: Keith Wyss --- .../blockwise_fp8_gemm_reference.py | 238 +++++ .../blockwise_quantizer_reference.py | 1 + .../test_float8_blockwise_gemm_exact.py | 832 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 263 ++++-- .../common/normalization/layernorm/ln_api.cpp | 4 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 4 +- .../csrc/extensions/type_converters.cpp | 7 + 7 files changed, 1262 insertions(+), 87 deletions(-) create mode 100644 tests/pytorch/references/blockwise_fp8_gemm_reference.py create mode 100644 tests/pytorch/test_float8_blockwise_gemm_exact.py diff --git a/tests/pytorch/references/blockwise_fp8_gemm_reference.py b/tests/pytorch/references/blockwise_fp8_gemm_reference.py new file mode 100644 index 0000000000..3487dfb810 --- /dev/null +++ b/tests/pytorch/references/blockwise_fp8_gemm_reference.py @@ -0,0 +1,238 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128): + pid = tl.program_id(0) + idx = pid * BLOCK + tl.arange(0, BLOCK) + mask = idx < M * N + + row = idx // N + col = idx % N + + y_offset = row * y_str0 + col * y_str1 + x_offset = row * N + col + s_offset = row * N + col + + y = tl.load(y_ptr + y_offset, mask=mask) + x = tl.load(x_ptr + x_offset, mask=mask) + s = tl.load(s_ptr + s_offset, mask=mask) + + tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask) + + +def fused_fma(y, x, s, BLOCK=128): + """ + Fused multiply-add operation (y = y + x * s). + + PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation). + This function also supports cases where 'y' is non-contiguous in memory. 
+ """ + + assert ( + y.shape == x.shape == s.shape and y.dim() == 2 + ), "All tensors must be 2D with the same shape" + assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous" + + M, N = y.shape + grid = ((M * N + BLOCK - 1) // BLOCK,) + + fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK) + + return y + + +class CuBLASRefBlockwiseGemm: + """ + A cuBLAS compatible reference implementation of subchannel GEMM. + """ + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + demunged_sx: torch.Tensor, + demunged_sw: torch.Tensor, + quant_tile_shape_x: Tuple[int, int], + quant_tile_shape_w: Tuple[int, int], + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + use_split_accumulator: bool = False, + ) -> torch.Tensor: + # demunge scale shapes for cuBLAS + is_a_1d_scaled = quant_tile_shape_x[0] == 1 + is_b_1d_scaled = quant_tile_shape_w[0] == 1 + M, K = qx.shape + N, K = qw.shape + + # mm_tile_shape = (tile_m, tile_n, tile_k) + mm_tile_shape = ( + quant_tile_shape_x[0], + quant_tile_shape_w[0], + quant_tile_shape_w[1], + ) + if bias is not None and bias.numel(): + # To match cuBLAS more closely when bias is applied, + # the reference accumulates into float32, and cast to + # bfloat16 is deferred until after the GEMM. + out_dtype_for_ref = torch.float32 + else: + out_dtype_for_ref = out_dtype + y = self.qgemm_blockwise_2d( + qx, + qw, + out_dtype_for_ref, + demunged_sx, + demunged_sw, + mm_tile_shape, + use_split_accumulator, + is_a_1d_scaled, + is_b_1d_scaled, + ) + if bias is not None and bias.numel(): + y += bias + y = y.to(dtype=out_dtype) + # cublas accumulation first convert to output dtype, then accumulate. + if accumulate: + assert out is not None + y = y + out + else: + assert out is None, "Output tensor should be None when accumulate is False." 
+ + return y + + @classmethod + def qgemm_blockwise_2d( + cls, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + mm_tile_shape: Tuple[int, int, int], + use_split_accumulator: bool, + is_a_1d_scaled: bool, + is_b_1d_scaled: bool, + ) -> torch.Tensor: + """ + Difference between cuBLAS and CUTLASS GEMM implementations: + - cuBLAS accumulation equation: use different equation for each scaling mode. + - For accumulation C in epiloge, it first convert C to output dtype, then accumulate. + """ + + M, K = qx.shape + N, K_w = qw.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + tile_len = 128 + # Calculate grid sizes without padding + grid_m = (M + tile_len - 1) // tile_len + grid_n = (N + tile_len - 1) // tile_len + grid_k = (K + tile_len - 1) // tile_len + + block_m, block_n, block_k = mm_tile_shape + scale_m_per_tile = tile_len // block_m + scale_n_per_tile = tile_len // block_n + assert block_k == tile_len, "block_k must be equal to tile_len" + + # Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM: + # 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers. 
+ # 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # Validate shapes of sx and sw + scale_m_per_tensor = (M + block_m - 1) // block_m + scale_n_per_tensor = (N + block_n - 1) // block_n + assert sx.shape == ( + scale_m_per_tensor, + grid_k, + ), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}" + assert sw.shape == ( + scale_n_per_tensor, + grid_k, + ), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}" + + for i in range(grid_m): + m_start = i * tile_len + m_end = min(m_start + tile_len, M) + m_size = m_end - m_start + + for j in range(grid_n): + n_start = j * tile_len + n_end = min(n_start + tile_len, N) + n_size = n_end - n_start + + y_block = y[m_start:m_end, n_start:n_end] + + for k in range(grid_k): + k_start = k * tile_len + k_end = min(k_start + tile_len, K) + k_size = k_end - k_start + + qx_block = ( + qx[m_start:m_end, k_start:k_end].clone().contiguous() + ) # Shape: [m_size, k_size] + qw_block = ( + qw[n_start:n_end, k_start:k_end].clone().contiguous() + ) # Shape: [n_size, k_size] + + # Extract scaling factors for the current blocks + sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze( + -1 + ) + sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0) + + # Perform qgemm with scaling factors fused in the GEMM + # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM + one = torch.tensor(1.0, dtype=torch.float32, device=qx.device) + y_partial = torch._scaled_mm( + qx_block, + qw_block.t(), + scale_a=one, + scale_b=one, + out_dtype=torch.float32, + use_fast_accum=not use_split_accumulator, + ) + + # Accumulate the partial result + if is_a_1d_scaled and is_b_1d_scaled: + # 1Dx1D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + 
y_partial = y_partial * sx_block + # Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM + # y_block.add_(y_partial, alpha=scale.item()) + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + elif not is_a_1d_scaled and is_b_1d_scaled: + # 2Dx1D + # CuBLAS accumulation equation: y += (y * scale_b) * scale_a + y_partial = y_partial * sw_block + fused_fma( + y_block, + y_partial, + sx_block.expand_as(y_partial).contiguous(), + ) + elif is_a_1d_scaled and not is_b_1d_scaled: + # 1Dx2D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + else: + scale = sx_block * sw_block + fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous()) + + y = y.to(out_dtype) + return y diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index b98966f514..f5c9dc0e96 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -49,6 +49,7 @@ def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1) return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + @classmethod def demunge_scale_shape_from_backend( cls, qtensor_shape: Tuple[int, int], diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py new file mode 100644 index 0000000000..a118c6f81c --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -0,0 +1,832 @@ +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + 
Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from tests.pytorch.references.blockwise_quantizer_reference import CuBLASScaleMunger +from tests.pytorch.references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm + + +def fp8_blockwise_gemm_supported() -> bool: + return float(torch.version.cuda) >= 12.8 + + +def cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + atol: float = 0.0, + rtol: float = 0.0 +): + if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2: + pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2") + if not (is_x_1d_scaled or is_w_1d_scaled): + pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile") + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + if noise_type == "uniform": + x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude + w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude + elif noise_type == "normal": + x = torch.randn(x_shape, dtype=torch.float32, device=device) * x_magnitude + w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude + else: + assert False + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude + else: + out = None + + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 
128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Reference GEMM + ref_gemm = CuBLASRefBlockwiseGemm() + scale_decoder = CuBLASScaleMunger() + qx_data = ( + qx._columnwise_data.view(dtype=x_dtype) + if x_columnwise + else qx._rowwise_data.view(dtype=x_dtype) + ) + qw_data = ( + qw._columnwise_data.view(dtype=w_dtype) + if w_columnwise + else qw._rowwise_data.view(dtype=w_dtype) + ) + ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv + ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv + y_ref = ref_gemm.qgemm( + qx=qx_data, + qw=qw_data, + out_dtype=out_dtype, + demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape + ), + demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape + ), + quant_tile_shape_x=x_quant_tile_shape, + quant_tile_shape_w=w_quant_tile_shape, + bias=bias, + 
out=out.clone() if accumulate else None, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + grad = False + gelu = False + gelu_in = None + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + gelu, + gelu_in, + grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y are not the same tensor + assert y_ref is not y, "y_ref and y should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y = torch.where(y.isnan(), torch.zeros_like(y), y) + + # Check + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + + +def cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED", + expected_err_cls=RuntimeError +): + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if 
x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + grad = use_grad + gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device) + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, 
gelu_input, extra_output + # We are just capturing out. + with pytest.raises(expected_err_cls, match=expected_err_msg): + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_in, + grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (128, 128, 128), + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (16, 64, 128), + (128, 160, 128), + (320, 128, 336), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (320, 256, 336), + (256, 512, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (1024, 4096, 1024), + (512, 128, 512), + (768, 128, 768), + (1024, 128, 1024), + (1536, 128, 1536), + (2048, 128, 2048), + (4096, 128, 4096), + (4096, 512, 3072), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_shape_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + 
noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_bias( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + rtol = 1e-3 + atol = 0.0 + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_bias=True, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) 
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["colxrow", "colxcol", "rowxcol"], +) +def test_cublas_gemm_fp8_blockwise_columnwise( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + is_x_columnwise, + is_w_columnwise, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_split_accumulator_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, 
+ N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_bgrad_not_supported_until_tested( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # NOTE: This may work, but until it is tested thoroughly, + # testing that the implementation errors. 
+ cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=True, + use_bias=True, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"]) +@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_gelu_not_supported_until_tested( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_bias, + use_grad, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # NOTE: This may work, but until it is tested thoroughly, + # testing that the implementation errors. 
+ cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=use_grad, + use_bias=use_bias, + use_gelu=True, + expected_err_msg=( + "not supported for NVTE_BLOCK_SCALING until further numerical verification" + ), + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_illegal_dtype_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # e5m2 by e5m2 not supported. 
+ cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (False, False), + ], + ids=["2Dx2D"], +) +def test_illegal_2D_by_2D_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # 2D block quantization by 2D block quantization is not supported. + expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg=expected_err_msg, + ) + + +@pytest.mark.parametrize( + "M, K, N, legalX1d, legalX2d", + [ + # M dim unconstrained when X is 2D. 
+ (255, 128, 256, False, True), + # K must be multiple of 16 + (256, 120, 256, False, False), + # N must be a multiple of 8 + (256, 128, 252, False, False), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (False, True), + (True, True), + ], + ids=["1Dx2D", "2Dx1D", "1Dx1D"], +) +def test_unaligned_shapes( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + legalX1d, + legalX2d, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + legal = legalX1d if is_x_1d_scaled else legalX2d + if not legal: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg="dimension requirement", + ) + else: + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + "uniform", # noise type + 1.0, # x_magnitude + 1.0, # w_magnitude + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f19465c44b..a4a0a2c32d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -55,14 +55,23 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { struct GemmParam { void *A; void *B; + // The layout (e.g. 
TN to call cublas with) cublasOperation_t transA; cublasOperation_t transB; transformer_engine::DType Atype; transformer_engine::DType Btype; void *A_scale_inv; void *B_scale_inv; + // Element stride for A int lda; + // Element stride for B int ldb; + // major and minor number of elements for the + // storage of A, and B of GemmParam + int a_major_dim; + int a_minor_dim; + int b_major_dim; + int b_minor_dim; GemmParam(cublasOperation_t transA, cublasOperation_t transB) : A(nullptr), @@ -74,27 +83,78 @@ struct GemmParam { A_scale_inv(nullptr), B_scale_inv(nullptr), lda(0), - ldb(0) {} + ldb(0), + a_major_dim(0), + a_minor_dim(0), + b_major_dim(0), + b_minor_dim(0) {} }; GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - const int k, const int lda, const int ldb) { + int A0, int A1, int B0, int B1) { using namespace transformer_engine; - // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. - // Must either force them both into a common block scaling mode or loosen this - // restriction. NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret(transA, transB); - ret.lda = lda; - ret.ldb = ldb; + bool transa_bool = transA == CUBLAS_OP_T; + bool transb_bool = transB == CUBLAS_OP_T; + + int arch = cuda::sm_arch(cuda::current_device()); + if (A.scaling_mode == NVTE_BLOCK_SCALING) { + // For this scaling mode, the quantizer stores + // rowwise data and transposes the data for columnwise + // data so the physical layout is always row major + // and the transA and transB values to pass to cublas + // should always be TN. + + ret.a_major_dim = transa_bool ? A0 : A1; + ret.a_minor_dim = transa_bool ? 
A1 : A0; + ret.b_major_dim = transb_bool ? B1 : B0; + ret.b_minor_dim = transb_bool ? B0 : B1; + + ret.transA = CUBLAS_OP_T; + ret.transB = CUBLAS_OP_N; + ret.lda = ret.a_minor_dim; + ret.ldb = ret.b_minor_dim; + + NVTE_CHECK(ret.a_minor_dim == ret.b_minor_dim, + "Inner dimension must be equal for NVTE_BLOCK_SCALING Gemm."); + + } else { + // In these scaling modes, the physical layout of + // the tensor will always line up with transA and + // transB, which are passed along to cuBLAS. + // NOTE: There is some logic below that may edit this + // decision for A and B depending on dtype and arch. + const int m = transa_bool ? A0 : A1; + const int k = transa_bool ? A1 : A0; + const int n = transb_bool ? B1 : B0; + ret.a_major_dim = A0; + ret.a_minor_dim = A1; + ret.b_major_dim = B0; + ret.b_minor_dim = B1; + + int lda, ldb; + if (transa_bool && !transb_bool) { // TN + lda = k; + ldb = k; + } else if (!transa_bool && !transb_bool) { // NN + lda = m; + ldb = k; + } else if (!transa_bool && transb_bool) { // NT + lda = m; + ldb = n; + } else { // TT + NVTE_ERROR("TT layout not allowed."); + } + ret.lda = lda; + ret.ldb = ldb; + } - // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases - // or need to be treated as `is_tensor_scaling`. if (is_tensor_scaling(A.scaling_mode)) { ret.A = A.data.dptr; ret.A_scale_inv = A.scale_inv.dptr; @@ -103,14 +163,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { ret.Atype = A.has_columnwise_data() ? 
A.columnwise_data.dtype : A.data.dtype; if (is_fp8_dtype(ret.Atype)) { - int arch = cuda::sm_arch(cuda::current_device()); if (arch < 100) { // Hopper and Ada - we need to use columnwise_data and change transA NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); ret.A = A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; ret.A_scale_inv = A.columnwise_scale_inv.dptr; - ret.lda = k; + ret.a_major_dim = A1; + ret.a_minor_dim = A0; + ret.lda = A0; } } } @@ -119,29 +180,63 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (transB == CUBLAS_OP_T) { ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; if (is_fp8_dtype(ret.Btype)) { - int arch = cuda::sm_arch(cuda::current_device()); if (arch < 100) { // Hopper and Ada - we need to use columnwise_data and change transA NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); ret.B = B.columnwise_data.dptr; ret.transB = CUBLAS_OP_N; ret.B_scale_inv = B.columnwise_scale_inv.dptr; - ret.ldb = k; + ret.ldb = B0; + ret.b_major_dim = B1; + ret.b_minor_dim = B0; } } } else { ret.Btype = B.data.dtype; } } else { + // MXF8 scaling or NVTE_BLOCK_SCALING // If not tensor scaling (which includes also high precision types), we need to // use the proper version of data - // We leave the transA/B values as is, since Blackwell supports transposes - ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; - ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; - ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transB ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + // For MXF8, we leave the transA/B values as is, since Blackwell supports transposes + // but for NVTE_BLOCK_SCALING, we force transA/B to TN since the quantizers + // store data in that manner and the GEMM requires that layout. + if (A.scaling_mode == NVTE_BLOCK_SCALING) { + if (transA == CUBLAS_OP_T) { + NVTE_CHECK(A.has_data(), "Input A is not suitable for rowwise usage!"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); + } + if (transB == CUBLAS_OP_N) { + NVTE_CHECK(B.has_data(), "Input B is not suitable for rowwise usage!"); + } else { + NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); + } + // Requirements from + // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.a_minor_dim % 16) == 0, + "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. + // Smallest supported CType is 2 bytes in this scaling mode. + NVTE_CHECK((ret.a_major_dim % 8) == 0, + "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Observed this requirement only present for B tensor is 1D quantized. + if (B.block_scaling_dim == 1) { + NVTE_CHECK( + (ret.b_major_dim % 8) == 0, + "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } + NVTE_CHECK((ret.lda % 16) == 0, + "A tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + NVTE_CHECK((ret.ldb % 16) == 0, + "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.Btype = transb_bool ? 
B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; } return ret; } @@ -153,18 +248,23 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, - int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, - void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, + const Tensor *inputBias, Tensor *outputPreGelu, int A0, int A1, int B0, int B1, + cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int ldd = m; // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { return; } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, A0, A1, B0, B1); + void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -222,10 +322,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, - param.transA == CUBLAS_OP_N ? k : m, param.lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, - param.transB == CUBLAS_OP_N ? 
n : k, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( + &Adesc, A_type, param.transA == CUBLAS_OP_N ? param.a_major_dim : param.a_minor_dim, + param.transA == CUBLAS_OP_N ? param.a_minor_dim : param.a_major_dim, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( + &Bdesc, B_type, param.transB == CUBLAS_OP_N ? param.b_minor_dim : param.b_major_dim, + param.transB == CUBLAS_OP_N ? param.b_major_dim : param.b_minor_dim, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); @@ -249,12 +352,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized - // GEMM types. - // Scaling factors. #if CUDA_VERSION >= 12080 - cublasLtMatmulMatrixScale_t scaling_mode; + cublasLtMatmulMatrixScale_t scaling_mode_a; + cublasLtMatmulMatrixScale_t scaling_mode_b; #endif if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { void *A_scale_inverse = param.A_scale_inv; @@ -266,8 +367,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); #if CUDA_VERSION >= 12080 - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; - } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); 
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -276,7 +378,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. if (cublasLtGetVersion() <= 120803) { @@ -285,6 +388,30 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } +#if CUDA_VERSION >= 12080 + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + int block_scaling_dim_a = inputA->block_scaling_dim; + int block_scaling_dim_b = inputB->block_scaling_dim; + NVTE_CHECK((block_scaling_dim_a == 1 && block_scaling_dim_b == 1) || + (block_scaling_dim_a == 1 && block_scaling_dim_b == 2) || + (block_scaling_dim_a == 2 && block_scaling_dim_b == 1), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got " + + std::to_string(block_scaling_dim_a) + " x " + + std::to_string(block_scaling_dim_b)); + 
scaling_mode_a = block_scaling_dim_a == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_b = block_scaling_dim_b == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; +#endif #endif } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + @@ -293,9 +420,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUDA_VERSION >= 12080 NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); #endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output @@ -305,8 +432,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUDA_VERSION >= 12080 - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + // NOTE: In all current cases where FP8 output is supported, the input is + // scaled identically to the output. 
+ NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_D_SCALE_MODE, + &scaling_mode_a, sizeof(scaling_mode_a))); #endif // For FP8 output, cuBLAS requires C_type to match bias_type and // be FP16/BF16 @@ -364,6 +494,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } + if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS || + epilogue == CUBLASLT_EPILOGUE_BGRADB), + "Epilogue (gelu fusion) not supported for NVTE_BLOCK_SCALING until further " + "numerical verification."); + } + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); @@ -411,7 +549,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C @@ -474,27 +611,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons const size_t B0 = inputB->flat_first_dim(); const size_t B1 = inputB->flat_last_dim(); - const int m = transa ? A0 : A1; - const int k = transa ? A1 : A0; - const int n = transb ? 
B1 : B0; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, A0, A1, B0, B1, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); @@ -525,28 +642,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, inputA->data.shape[0], + inputA->data.shape[1], inputB->data.shape[0], inputB->data.shape[1], (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index dae39d82bf..f6b6ae22c2 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -27,7 +27,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -57,7 +57,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (cudnn_backend) { // TODO: add check for GPU ARCH diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 8519fe1b64..c56f9ef407 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -23,7 +23,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + 
to_string(z->scaling_mode) + "."); } @@ -47,7 +47,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index cb2121a457..e8e8b06a4c 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -112,6 +112,13 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } + if (!tensor.attr("_quantizer").is_none()) { + // Some calls to makeTransformerEngineTensor pass a NoneQuantizer. + // The quantizer stores settings like block_scaling_dim that are important. + // and are stored indirectly via the quantizer. + auto tensor_meta_quantizer = CreateQuantizer(tensor.attr("_quantizer")); + tensor_meta_quantizer->set_quantization_params(&ret); + } quantizer->set_quantization_params(&ret); return ret; } From 522ffbe14fed6150547e79e73f3edef005d740b9 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 16:50:56 -0700 Subject: [PATCH 02/21] Update NVTE_BLOCK_SCALING for GEMM. 
Signed-off-by: Keith Wyss --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index a4a0a2c32d..ff179f3569 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -94,8 +94,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, int A0, int A1, int B0, int B1) { using namespace transformer_engine; - NVTE_CHECK(A.scaling_mode == B.scaling_mode, - "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK(A.scaling_mode == B.scaling_mode || + (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || + (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), + "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret(transA, transB); @@ -104,7 +106,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool transb_bool = transB == CUBLAS_OP_T; int arch = cuda::sm_arch(cuda::current_device()); - if (A.scaling_mode == NVTE_BLOCK_SCALING) { + int a_major_dim; + int b_major_dim; + if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // For this scaling mode, the quantizer stores // rowwise data and transposes the data for columnwise // data so the physical layout is always row major @@ -201,7 +205,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // For MXF8, we leave the transA/B values as is, since Blackwell supports transposes // but for NVTE_BLOCK_SCALING, we force 
transA/B to TN since the quantizers // store data in that manner and the GEMM requires that layout. - if (A.scaling_mode == NVTE_BLOCK_SCALING) { + if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { if (transA == CUBLAS_OP_T) { NVTE_CHECK(A.has_data(), "Input A is not suitable for rowwise usage!"); } else { @@ -221,7 +225,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK((ret.a_major_dim % 8) == 0, "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Observed this requirement only present for B tensor is 1D quantized. - if (B.block_scaling_dim == 1) { + if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { NVTE_CHECK( (ret.b_major_dim % 8) == 0, "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); @@ -389,8 +393,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, sizeof(dummy_a_vec_stride))); } #if CUDA_VERSION >= 12080 - } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && - (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -399,17 +403,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - int block_scaling_dim_a = inputA->block_scaling_dim; - int block_scaling_dim_b = inputB->block_scaling_dim; - NVTE_CHECK((block_scaling_dim_a == 1 && block_scaling_dim_b == 1) || - (block_scaling_dim_a == 1 
&& block_scaling_dim_b == 2) || - (block_scaling_dim_a == 2 && block_scaling_dim_b == 1), - "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got " + - std::to_string(block_scaling_dim_a) + " x " + - std::to_string(block_scaling_dim_b)); - scaling_mode_a = block_scaling_dim_a == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; - scaling_mode_b = block_scaling_dim_b == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; #endif #endif From d7e1fce86c4179554e57e15279284d8d119cbb65 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 6 Mar 2025 11:17:27 -0800 Subject: [PATCH 03/21] Gate feature on CUDA 12.9 Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 2 +- transformer_engine/common/gemm/cublaslt_gemm.cu | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index a118c6f81c..3da9e95c17 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -13,7 +13,7 @@ def fp8_blockwise_gemm_supported() -> bool: - return float(torch.version.cuda) >= 12.8 + return float(torch.version.cuda) >= 12.9 def cublas_gemm_fp8_blockwise_case( diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ff179f3569..c51397ac6e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ 
b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -392,8 +392,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } -#if CUDA_VERSION >= 12080 - } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && +#if CUDA_VERSION >= 12090 +else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); From f212c81dc2cd93da2e363689b90276432eb7d5fa Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 18:18:18 -0700 Subject: [PATCH 04/21] Gemm typo. Signed-off-by: Keith Wyss --- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c51397ac6e..5031db539f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -393,7 +393,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, sizeof(dummy_a_vec_stride))); } #if CUDA_VERSION >= 12090 -else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); From 
48b2d57923cb14510e42063c8b566a8c40bd95d4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 17:40:24 -0700 Subject: [PATCH 05/21] Remove unnecessary type converter change. Signed-off-by: Keith Wyss --- .../pytorch/csrc/extensions/type_converters.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index e8e8b06a4c..cb2121a457 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -112,13 +112,6 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } - if (!tensor.attr("_quantizer").is_none()) { - // Some calls to makeTransformerEngineTensor pass a NoneQuantizer. - // The quantizer stores settings like block_scaling_dim that are important. - // and are stored indirectly via the quantizer. - auto tensor_meta_quantizer = CreateQuantizer(tensor.attr("_quantizer")); - tensor_meta_quantizer->set_quantization_params(&ret); - } quantizer->set_quantization_params(&ret); return ret; } From 57615893b8f0fe67b2ea5c1f6a8f30160d6699d2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 11 Mar 2025 13:24:05 -0700 Subject: [PATCH 06/21] Reflect epilogue availability and test supported epilogues. 
Signed-off-by: Keith Wyss --- .../test_float8_blockwise_gemm_exact.py | 130 +++++++++++++++--- .../common/gemm/cublaslt_gemm.cu | 37 ++--- 2 files changed, 134 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 3da9e95c17..c52ced214d 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -34,6 +34,8 @@ def cublas_gemm_fp8_blockwise_case( x_columnwise: bool = False, w_columnwise: bool = False, use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, atol: float = 0.0, rtol: float = 0.0 ): @@ -67,6 +69,7 @@ def cublas_gemm_fp8_blockwise_case( else: out = None + assert not (use_bias and use_grad), "Bias grad not supported by GEMM" # Set quantize_op and quantization parameters x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) @@ -142,9 +145,9 @@ def cublas_gemm_fp8_blockwise_case( transa = True if not w_columnwise else False transb = False if not x_columnwise else True out_quantizer = None - grad = False - gelu = False - gelu_in = None + assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM" + aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None + aux_tensor_ref = aux_tensor.clone() if use_gelu else None bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] # cuBLAS GEMM @@ -160,9 +163,9 @@ def cublas_gemm_fp8_blockwise_case( TE_DType[out_dtype], bias, bias_dtype, - gelu, - gelu_in, - grad, + use_gelu, + aux_tensor, + use_grad, workspace, workspace.shape[0], accumulate, @@ -176,8 +179,25 @@ def cublas_gemm_fp8_blockwise_case( y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) y = torch.where(y.isnan(), torch.zeros_like(y), y) - # Check - torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + if use_gelu: + # Check + if 
use_grad: + # With use_grad, GEMM should use aux tensor to calculate + # gradient + gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None) + # TODO: How do we decide whether this is acceptably close? + # Could also try to put the activation inside the reference + # before the output cast to see different tolerances. + torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2) + else: + # aux tensor is pre-gelu aux output. Verify against y_ref. + torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol) + act = torch.nn.GELU() + gelu_ref = act(y_ref) + # gelu_ref = tex.gelu(y_ref, None) + torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol) + else: + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) def cublas_gemm_test_constraint_enforced( @@ -509,6 +529,84 @@ def test_cublas_gemm_fp8_blockwise_columnwise( ) +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "use_grad", + [ + True, + ], + ids=["grad"], +) +def test_cublas_gemm_fp8_gelu( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + 
x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad, +): + # NOTE: cuBLAS doesn't complain with not use_grad, but the tests don't succeed + # so the epilogue is disabled on the transformer engine side. + if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled): + pytest.skip( + "CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)." + ) + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_gelu=True, + use_grad=use_grad, + ) + + @pytest.mark.parametrize( "M, K, N", [ @@ -577,7 +675,7 @@ def test_split_accumulator_enforced( ], ids=["1Dx2D", "1Dx1D", "2Dx1D"], ) -def test_bgrad_not_supported_until_tested( +def test_bgrad_not_supported( x_dtype, w_dtype, out_dtype, @@ -589,8 +687,7 @@ def test_bgrad_not_supported_until_tested( is_x_1d_scaled, is_w_1d_scaled, ) -> None: - # NOTE: This may work, but until it is tested thoroughly, - # testing that the implementation errors. + # NOTE: BGRAD epilogue is not supported for fp8. cublas_gemm_test_constraint_enforced( x_dtype, w_dtype, @@ -604,6 +701,7 @@ def test_bgrad_not_supported_until_tested( is_w_1d_scaled, use_grad=True, use_bias=True, + expected_err_msg="Epilogue requested outside of the available", ) @@ -630,7 +728,7 @@ def test_bgrad_not_supported_until_tested( ], ids=["1Dx2D", "1Dx1D", "2Dx1D"], ) -def test_gelu_not_supported_until_tested( +def test_gelu_unsupported_cases_error( x_dtype, w_dtype, out_dtype, @@ -644,8 +742,8 @@ def test_gelu_not_supported_until_tested( is_x_1d_scaled, is_w_1d_scaled, ) -> None: - # NOTE: This may work, but until it is tested thoroughly, - # testing that the implementation errors. 
+ if use_grad and not use_bias: + pytest.skip("DGELU epilogue is supported.") cublas_gemm_test_constraint_enforced( x_dtype, w_dtype, @@ -660,9 +758,7 @@ def test_gelu_not_supported_until_tested( use_grad=use_grad, use_bias=use_bias, use_gelu=True, - expected_err_msg=( - "not supported for NVTE_BLOCK_SCALING until further numerical verification" - ), + expected_err_msg="Epilogue requested outside of the available", ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 5031db539f..419230ecdc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -94,10 +94,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, int A0, int A1, int B0, int B1) { using namespace transformer_engine; - NVTE_CHECK(A.scaling_mode == B.scaling_mode || - (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || - (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), - "Inputs A and B to GEMM need to have compatible scaling modes!"); + NVTE_CHECK( + A.scaling_mode == B.scaling_mode || + (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || + (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), + "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret(transA, transB); @@ -393,8 +394,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, sizeof(dummy_a_vec_stride))); } #if CUDA_VERSION >= 12090 - } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && - 
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -404,12 +407,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && - inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); - scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F - : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; - scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F - : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D + ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D + ? 
CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; #endif #endif } else { @@ -493,12 +498,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } - if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && - (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) || + (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) { NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS || - epilogue == CUBLASLT_EPILOGUE_BGRADB), - "Epilogue (gelu fusion) not supported for NVTE_BLOCK_SCALING until further " - "numerical verification."); + epilogue == CUBLASLT_EPILOGUE_DGELU), + "Epilogue requested outside of the available and tested cuBLAS functionality for " + "float8 block scaled GEMM"); } NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, From 07b19b7bfbdd72bf9036cdaeb41532b7faf91eca Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 11 Mar 2025 17:47:46 -0700 Subject: [PATCH 07/21] GEMM simplifications from recipe branch. 
Signed-off-by: Keith Wyss --- .../test_float8_blockwise_gemm_exact.py | 5 -- .../common/gemm/cublaslt_gemm.cu | 80 ++++++------------- 2 files changed, 23 insertions(+), 62 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index c52ced214d..2b2911c32f 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -15,7 +15,6 @@ def fp8_blockwise_gemm_supported() -> bool: return float(torch.version.cuda) >= 12.9 - def cublas_gemm_fp8_blockwise_case( x_dtype, w_dtype, @@ -432,8 +431,6 @@ def test_cublas_gemm_fp8_blockwise_bias( is_x_1d_scaled, is_w_1d_scaled, ): - rtol = 1e-3 - atol = 0.0 cublas_gemm_fp8_blockwise_case( x_dtype, w_dtype, @@ -449,8 +446,6 @@ def test_cublas_gemm_fp8_blockwise_bias( is_x_1d_scaled, is_w_1d_scaled, use_bias=True, - atol=atol, - rtol=rtol, ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 419230ecdc..ed13fccaef 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -62,16 +62,10 @@ struct GemmParam { transformer_engine::DType Btype; void *A_scale_inv; void *B_scale_inv; - // Element stride for A + // ld are leading dimensions or minor dimensions + // in storage int lda; - // Element stride for B int ldb; - // major and minor number of elements for the - // storage of A, and B of GemmParam - int a_major_dim; - int a_minor_dim; - int b_major_dim; - int b_minor_dim; GemmParam(cublasOperation_t transA, cublasOperation_t transB) : A(nullptr), @@ -83,11 +77,7 @@ struct GemmParam { A_scale_inv(nullptr), B_scale_inv(nullptr), lda(0), - ldb(0), - a_major_dim(0), - a_minor_dim(0), - b_major_dim(0), - b_minor_dim(0) {} + ldb(0) {} }; GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, @@ -116,18 +106,15 @@ GemmParam 
CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // and the transA and transB values to pass to cublas // should always be TN. - ret.a_major_dim = transa_bool ? A0 : A1; - ret.a_minor_dim = transa_bool ? A1 : A0; - ret.b_major_dim = transb_bool ? B1 : B0; - ret.b_minor_dim = transb_bool ? B0 : B1; + a_major_dim = transa_bool ? A0 : A1; + b_major_dim = transb_bool ? B1 : B0; + ret.lda = transa_bool ? A1 : A0; + ret.ldb = transb_bool ? B0 : B1; ret.transA = CUBLAS_OP_T; ret.transB = CUBLAS_OP_N; - ret.lda = ret.a_minor_dim; - ret.ldb = ret.b_minor_dim; - NVTE_CHECK(ret.a_minor_dim == ret.b_minor_dim, - "Inner dimension must be equal for NVTE_BLOCK_SCALING Gemm."); + NVTE_CHECK(ret.lda == ret.ldb, "Minor dimension must be equal for NVTE_BLOCK_SCALING Gemm."); } else { // In these scaling modes, the physical layout of @@ -135,29 +122,14 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // transB, which are passed along to cuBLAS. // NOTE: There is some logic below that may edit this // decision for A and B depending on dtype and arch. - const int m = transa_bool ? A0 : A1; - const int k = transa_bool ? A1 : A0; - const int n = transb_bool ? 
B1 : B0; - ret.a_major_dim = A0; - ret.a_minor_dim = A1; - ret.b_major_dim = B0; - ret.b_minor_dim = B1; - - int lda, ldb; - if (transa_bool && !transb_bool) { // TN - lda = k; - ldb = k; - } else if (!transa_bool && !transb_bool) { // NN - lda = m; - ldb = k; - } else if (!transa_bool && transb_bool) { // NT - lda = m; - ldb = n; - } else { // TT + a_major_dim = A0; + b_major_dim = B0; + ret.lda = A1; + ret.ldb = B1; + + if (transa_bool && transb_bool) { // TT NVTE_ERROR("TT layout not allowed."); } - ret.lda = lda; - ret.ldb = ldb; } if (is_tensor_scaling(A.scaling_mode)) { @@ -174,8 +146,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A = A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; ret.A_scale_inv = A.columnwise_scale_inv.dptr; - ret.a_major_dim = A1; - ret.a_minor_dim = A0; + a_major_dim = A1; ret.lda = A0; } } @@ -191,9 +162,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B = B.columnwise_data.dptr; ret.transB = CUBLAS_OP_N; ret.B_scale_inv = B.columnwise_scale_inv.dptr; + b_major_dim = B1; ret.ldb = B0; - ret.b_major_dim = B1; - ret.b_minor_dim = B0; } } } else { @@ -219,20 +189,18 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } // Requirements from // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK((ret.a_minor_dim % 16) == 0, + NVTE_CHECK((ret.lda % 16) == 0, "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. - NVTE_CHECK((ret.a_major_dim % 8) == 0, + NVTE_CHECK((a_major_dim % 8) == 0, "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Observed this requirement only present for B tensor is 1D quantized. 
if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { NVTE_CHECK( - (ret.b_major_dim % 8) == 0, + (b_major_dim % 8) == 0, "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); } - NVTE_CHECK((ret.lda % 16) == 0, - "A tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); NVTE_CHECK((ret.ldb % 16) == 0, "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); } @@ -327,12 +295,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( - &Adesc, A_type, param.transA == CUBLAS_OP_N ? param.a_major_dim : param.a_minor_dim, - param.transA == CUBLAS_OP_N ? param.a_minor_dim : param.a_major_dim, param.lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( - &Bdesc, B_type, param.transB == CUBLAS_OP_N ? param.b_minor_dim : param.b_major_dim, - param.transB == CUBLAS_OP_N ? param.b_major_dim : param.b_minor_dim, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, + param.transA == CUBLAS_OP_N ? k : m, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, + param.transB == CUBLAS_OP_N ? n : k, param.ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); From c4a41b88cb0e628098b6e821bdd4eb91134ae923 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 14 Mar 2025 17:24:32 -0700 Subject: [PATCH 08/21] Format py code. 
Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 2b2911c32f..022f754444 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -15,6 +15,7 @@ def fp8_blockwise_gemm_supported() -> bool: return float(torch.version.cuda) >= 12.9 + def cublas_gemm_fp8_blockwise_case( x_dtype, w_dtype, From 51ed2fb604289900f1b0677477f92ae2b5156c42 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 10:17:29 -0700 Subject: [PATCH 09/21] Update GEMM DGelu tests to match support depending on output dtype. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 022f754444..eedd7056c9 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -540,7 +540,7 @@ def test_cublas_gemm_fp8_blockwise_columnwise( ) @pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("noise_type", ["normal"], ids=str) @pytest.mark.parametrize("x_magnitude", [1], ids=str) @pytest.mark.parametrize("w_magnitude", [1], ids=str) @@ -738,8 +738,12 @@ def test_gelu_unsupported_cases_error( is_x_1d_scaled, is_w_1d_scaled, ) -> None: - if use_grad and not use_bias: - pytest.skip("DGELU epilogue is supported.") + if use_grad and not use_bias and out_dtype == torch.bfloat16: + pytest.skip("DGELU epilogue is 
supported for bfloat16.") + elif use_grad and not use_bias: + expected_err = "an unsupported value or parameter was passed" + else: + expected_err = "Epilogue requested outside of the available" cublas_gemm_test_constraint_enforced( x_dtype, w_dtype, @@ -754,7 +758,7 @@ def test_gelu_unsupported_cases_error( use_grad=use_grad, use_bias=use_bias, use_gelu=True, - expected_err_msg="Epilogue requested outside of the available", + expected_err_msg=expected_err, ) From e7af1404abd43de0b9bcea30378dff5c0f78e212 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 11:29:28 -0700 Subject: [PATCH 10/21] Force pow2Scales in GEMM Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index eedd7056c9..728f84fb2f 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -82,7 +82,7 @@ def cublas_gemm_fp8_blockwise_case( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=x_block_scaling_dim, ) w_quantizer = Float8BlockQuantizer( @@ -90,7 +90,7 @@ def cublas_gemm_fp8_blockwise_case( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=w_block_scaling_dim, ) @@ -252,7 +252,7 @@ def cublas_gemm_test_constraint_enforced( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=x_block_scaling_dim, ) w_quantizer = Float8BlockQuantizer( @@ -260,7 +260,7 @@ def cublas_gemm_test_constraint_enforced( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=w_block_scaling_dim, ) From 596a00912553e457909254f5b4217cba27ad0423 Mon Sep 17 00:00:00 2001 
From: Keith Wyss Date: Wed, 2 Apr 2025 11:42:26 -0700 Subject: [PATCH 11/21] Add GEMM test to pytorch test suite. Signed-off-by: Keith Wyss --- qa/L0_pytorch_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eaededc4..1206012195 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail " python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" From 4aa6067ef3707e2f023d7bbc7fe42f165b57812d Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 12:17:59 -0700 Subject: [PATCH 12/21] Add copyright to GEMM test. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 728f84fb2f..61cdef742c 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ import pytest import torch import transformer_engine as te From 758dc4a2cc1c4c3476635f87cd16a0b7687643e2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 18:05:13 -0700 Subject: [PATCH 13/21] Update import for GEMM test. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 61cdef742c..94014d36b5 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -12,8 +12,8 @@ Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from tests.pytorch.references.blockwise_quantizer_reference import CuBLASScaleMunger -from tests.pytorch.references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm +from references.blockwise_quantizer_reference import CuBLASScaleMunger +from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: From 7d5b5d99865501923613f11ba37da30141818a30 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 10:25:20 -0700 Subject: [PATCH 14/21] Add license. Signed-off-by: Keith Wyss --- tests/pytorch/references/blockwise_fp8_gemm_reference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pytorch/references/blockwise_fp8_gemm_reference.py b/tests/pytorch/references/blockwise_fp8_gemm_reference.py index 3487dfb810..5aef986e37 100644 --- a/tests/pytorch/references/blockwise_fp8_gemm_reference.py +++ b/tests/pytorch/references/blockwise_fp8_gemm_reference.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ from typing import Tuple import torch From efdf8e0d963b61c15eba7b29347f02dab2b33488 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 14:21:13 -0700 Subject: [PATCH 15/21] Update test gemm supported predicate. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 94014d36b5..9ddb4b9989 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -12,12 +12,17 @@ Float8BlockQuantizer, Float8BlockwiseQTensor, ) +from transformer_engine.pytorch.utils import get_device_compute_capability from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - return float(torch.version.cuda) >= 12.9 + return ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ) def cublas_gemm_fp8_blockwise_case( From a9f209acca6ebb9441a8c04b4e834da9f6d0ead2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 17:38:37 -0700 Subject: [PATCH 16/21] Use sgemm like interfaces and naming. 
Signed-off-by: Keith Wyss --- .../common/gemm/cublaslt_gemm.cu | 63 ++++++++++--------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ed13fccaef..aa88eb9bc4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -82,7 +82,7 @@ struct GemmParam { GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - int A0, int A1, int B0, int B1) { + int m, int k, int n) { using namespace transformer_engine; NVTE_CHECK( A.scaling_mode == B.scaling_mode || @@ -97,8 +97,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool transb_bool = transB == CUBLAS_OP_T; int arch = cuda::sm_arch(cuda::current_device()); - int a_major_dim; - int b_major_dim; + int a_storage_outer_dim; + int b_storage_outer_dim; if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // For this scaling mode, the quantizer stores // rowwise data and transposes the data for columnwise @@ -106,10 +106,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // and the transA and transB values to pass to cublas // should always be TN. - a_major_dim = transa_bool ? A0 : A1; - b_major_dim = transb_bool ? B1 : B0; - ret.lda = transa_bool ? A1 : A0; - ret.ldb = transb_bool ? B0 : B1; + a_storage_outer_dim = m; + b_storage_outer_dim = n; + ret.lda = k; + ret.ldb = k; ret.transA = CUBLAS_OP_T; ret.transB = CUBLAS_OP_N; @@ -122,10 +122,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // transB, which are passed along to cuBLAS. // NOTE: There is some logic below that may edit this // decision for A and B depending on dtype and arch. 
- a_major_dim = A0; - b_major_dim = B0; - ret.lda = A1; - ret.ldb = B1; + a_storage_outer_dim = transa_bool ? m : k; + b_storage_outer_dim = transb_bool ? k : n; + ret.lda = transa_bool ? k : m; + ret.ldb = transb_bool ? n : k; if (transa_bool && transb_bool) { // TT NVTE_ERROR("TT layout not allowed."); @@ -146,8 +146,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A = A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; ret.A_scale_inv = A.columnwise_scale_inv.dptr; - a_major_dim = A1; - ret.lda = A0; + a_storage_outer_dim = m; + ret.lda = k; } } } @@ -162,8 +162,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B = B.columnwise_data.dptr; ret.transB = CUBLAS_OP_N; ret.B_scale_inv = B.columnwise_scale_inv.dptr; - b_major_dim = B1; - ret.ldb = B0; + b_storage_outer_dim = n; + ret.ldb = k; } } } else { @@ -193,12 +193,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. - NVTE_CHECK((a_major_dim % 8) == 0, + NVTE_CHECK((a_storage_outer_dim % 8) == 0, "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Observed this requirement only present for B tensor is 1D quantized. if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { NVTE_CHECK( - (b_major_dim % 8) == 0, + (b_storage_outer_dim % 8) == 0, "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. 
Caller must pad."); } NVTE_CHECK((ret.ldb % 16) == 0, @@ -221,14 +221,11 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int A0, int A1, int B0, int B1, + const Tensor *inputBias, Tensor *outputPreGelu, int m, int k, int n, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { - const int m = transa == CUBLAS_OP_T ? A0 : A1; - const int k = transa == CUBLAS_OP_T ? A1 : A0; - const int n = transb == CUBLAS_OP_T ? B1 : B0; const int ldd = m; // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { @@ -236,7 +233,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, A0, A1, B0, B1); + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, k, n); void *C = outputD->data.dptr; void *D = outputD->data.dptr; @@ -359,11 +356,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } -#if CUDA_VERSION >= 12090 } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { +#if CUDA_VERSION >= 12090 float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -381,8 +378,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, 
Tensor *outputD, scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; -#endif -#endif +#else + NVTE_ERROR("FP8 block scaling requires CUDA 12.9+"); +#endif // CUDA_VERSION >= 12090 +#endif // CUDA_VERSION >= 12080 } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + to_string(inputB->scaling_mode) + "."); @@ -581,7 +580,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons const size_t B0 = inputB->flat_first_dim(); const size_t B1 = inputB->flat_last_dim(); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, A0, A1, B0, B1, + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); @@ -612,8 +615,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, inputA->data.shape[0], - inputA->data.shape[1], inputB->data.shape[0], inputB->data.shape[1], + + const int m = transa == CUBLAS_OP_T ? inputA->data.shape[0] : inputA->data.shape[1]; + const int k = transa == CUBLAS_OP_T ? inputA->data.shape[1] : inputA->data.shape[0]; + const int n = transb == CUBLAS_OP_T ? inputB->data.shape[0] : inputB->data.shape[1]; + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); From 861c8700175ab97cce0c0d5b56793b41f3c502d9 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 17:58:14 -0700 Subject: [PATCH 17/21] Rewrite GEMM comment. Signed-off-by: Keith Wyss --- transformer_engine/common/gemm/cublaslt_gemm.cu | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index aa88eb9bc4..6824eea7a1 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -100,11 +100,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla int a_storage_outer_dim; int b_storage_outer_dim; if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { - // For this scaling mode, the quantizer stores - // rowwise data and transposes the data for columnwise - // data so the physical layout is always row major - // and the transA and transB values to pass to cublas - // should always be TN. + // For this scaling mode, a quantized tensor of the data is stored + // in a row major layout for rowwise data and a quantized tensor of + // the transpose of the data is also stored in row major layout. + // + // cublas will be called with "TN", but Transformer engine uses + // the "TN" parameters to choose between rowwise and columnwise + // row major tensors. a_storage_outer_dim = m; b_storage_outer_dim = n; From ada643897c8b9bbbb64ca46d34de5f1d55390ada Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 18:27:52 -0700 Subject: [PATCH 18/21] MR Feedback. 
Signed-off-by: Keith Wyss --- transformer_engine/common/gemm/cublaslt_gemm.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 6824eea7a1..483a1380ef 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -582,9 +582,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons const size_t B0 = inputB->flat_first_dim(); const size_t B1 = inputB->flat_last_dim(); - const int m = transa == CUBLAS_OP_T ? A0 : A1; - const int k = transa == CUBLAS_OP_T ? A1 : A0; - const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int m = transa ? A0 : A1; + const int k = transa ? A1 : A0; + const int n = transb ? B1 : B0; cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, @@ -618,9 +618,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - const int m = transa == CUBLAS_OP_T ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa == CUBLAS_OP_T ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb == CUBLAS_OP_T ? inputB->data.shape[0] : inputB->data.shape[1]; + const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; + const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; + const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, From e484269c140010ed5830275f392b53649ec66842 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sun, 6 Apr 2025 01:11:33 +0000 Subject: [PATCH 19/21] Refactor GEMM param canonicalization Configure A and B matrices separately. Have separate code path for each scaling mode. Signed-off-by: Tim Moon --- .../common/gemm/cublaslt_gemm.cu | 300 +++++++++--------- 1 file changed, 153 insertions(+), 147 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 483a1380ef..ed691aa532 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -52,37 +52,36 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); } +/* Parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ struct GemmParam { - void *A; - void *B; - // The layout (e.g. 
TN to call cublas with) - cublasOperation_t transA; - cublasOperation_t transB; - transformer_engine::DType Atype; - transformer_engine::DType Btype; - void *A_scale_inv; - void *B_scale_inv; - // ld are leading dimensions or minor dimensions - // in storage - int lda; - int ldb; - - GemmParam(cublasOperation_t transA, cublasOperation_t transB) - : A(nullptr), - B(nullptr), - transA(transA), - transB(transB), - Atype(transformer_engine::DType::kNumTypes), - Btype(transformer_engine::DType::kNumTypes), - A_scale_inv(nullptr), - B_scale_inv(nullptr), - lda(0), - ldb(0) {} + void *A = nullptr; + void *B = nullptr; + cublasOperation_t transA = CUBLAS_OP_N; + cublasOperation_t transB = CUBLAS_OP_N; + transformer_engine::DType Atype = transformer_engine::DType::kNumTypes; + transformer_engine::DType Btype = transformer_engine::DType::kNumTypes; + void *A_scale_inv = nullptr; + void *B_scale_inv = nullptr; + int lda = 0; // A column strides + int ldb = 0; // B column strides }; +/* Populate parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. 
+ * + */ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - int m, int k, int n) { + int m, int n, int k) { using namespace transformer_engine; NVTE_CHECK( A.scaling_mode == B.scaling_mode || @@ -91,128 +90,135 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); - GemmParam ret(transA, transB); + GemmParam ret; + + // Device compute capability + const int arch = cuda::sm_arch(); + // Transpose mode with column-major ordering bool transa_bool = transA == CUBLAS_OP_T; bool transb_bool = transB == CUBLAS_OP_T; - int arch = cuda::sm_arch(cuda::current_device()); - int a_storage_outer_dim; - int b_storage_outer_dim; - if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { - // For this scaling mode, a quantized tensor of the data is stored - // in a row major layout for rowwise data and a quantized tensor of - // the transpose of the data is also stored in row major layout. - // - // cublas will be called with "TN", but Transformer engine uses - // the "TN" parameters to choose between rowwise and columnwise - // row major tensors. - - a_storage_outer_dim = m; - b_storage_outer_dim = n; - ret.lda = k; - ret.ldb = k; - - ret.transA = CUBLAS_OP_T; - ret.transB = CUBLAS_OP_N; - - NVTE_CHECK(ret.lda == ret.ldb, "Minor dimension must be equal for NVTE_BLOCK_SCALING Gemm."); - - } else { - // In these scaling modes, the physical layout of - // the tensor will always line up with transA and - // transB, which are passed along to cuBLAS. - // NOTE: There is some logic below that may edit this - // decision for A and B depending on dtype and arch. 
- a_storage_outer_dim = transa_bool ? m : k; - b_storage_outer_dim = transb_bool ? k : n; - ret.lda = transa_bool ? k : m; - ret.ldb = transb_bool ? n : k; - - if (transa_bool && transb_bool) { // TT - NVTE_ERROR("TT layout not allowed."); - } - } - + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.A = A.data.dptr; + ret.transA = transA; + ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - if (transA == CUBLAS_OP_T) { - ret.Atype = A.data.dtype; - } else { - ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; - if (is_fp8_dtype(ret.Atype)) { - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); - ret.A = A.columnwise_data.dptr; - ret.transA = CUBLAS_OP_T; - ret.A_scale_inv = A.columnwise_scale_inv.dptr; - a_storage_outer_dim = m; - ret.lda = k; - } + ret.lda = transa_bool ? k : m; + if (arch < 100 && !transa_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } - ret.B = B.data.dptr; - ret.B_scale_inv = B.scale_inv.dptr; - if (transB == CUBLAS_OP_T) { - ret.Btype = B.has_columnwise_data() ? 
B.columnwise_data.dtype : B.data.dtype; - if (is_fp8_dtype(ret.Btype)) { - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); - ret.B = B.columnwise_data.dptr; - ret.transB = CUBLAS_OP_N; - ret.B_scale_inv = B.columnwise_scale_inv.dptr; - b_storage_outer_dim = n; - ret.ldb = k; - } - } + } else if (is_mxfp_scaling(A.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - ret.Btype = B.data.dtype; + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); } - } else { - // MXF8 scaling or NVTE_BLOCK_SCALING - // If not tensor scaling (which includes also high precision types), we need to - // use the proper version of data - // For MXF8, we leave the transA/B values as is, since Blackwell supports transposes - // but for NVTE_BLOCK_SCALING, we force transA/B to TN since the quantizers - // store data in that manner and the GEMM requires that layout. - if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { - if (transA == CUBLAS_OP_T) { - NVTE_CHECK(A.has_data(), "Input A is not suitable for rowwise usage!"); - } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); - } - if (transB == CUBLAS_OP_N) { - NVTE_CHECK(B.has_data(), "Input B is not suitable for rowwise usage!"); - } else { - NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); - } - // Requirements from - // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK((ret.lda % 16) == 0, - "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); - // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. 
- // Smallest supported CType is 2 bytes in this scaling mode. - NVTE_CHECK((a_storage_outer_dim % 8) == 0, - "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); - // Observed this requirement only present for B tensor is 1D quantized. - if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { - NVTE_CHECK( - (b_storage_outer_dim % 8) == 0, - "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); - } - NVTE_CHECK((ret.ldb % 16) == 0, - "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = transA; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = m; + } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); } ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.lda % 16) == 0, + "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. + // Smallest supported CType is 2 bytes in this scaling mode. + NVTE_CHECK((m % 8) == 0, + "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. 
Caller must pad."); + } else { + NVTE_ERROR("A has unsupported scaling mode"); + } + + // Configure B matrix + if (is_tensor_scaling(B.scaling_mode)) { + // Unscaled or FP8 tensor scaling + ret.B = B.data.dptr; + ret.transB = transB; + ret.Btype = B.data.dtype; + ret.B_scale_inv = B.scale_inv.dptr; + ret.ldb = transb_bool ? n : k; + if (arch < 100 && transb_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); + } + } + } else if (is_mxfp_scaling(B.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = transB; + ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = transb_bool ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + + // Requirements from + // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.ldb % 16) == 0, + "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { + // Observed this requirement only present for B tensor is 1D quantized. + NVTE_CHECK((n % 8) == 0, + "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } + } else { + NVTE_ERROR("B has unsupported scaling mode"); } + return ret; } @@ -223,19 +229,33 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int m, int k, int n, + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + // Tensor dims in row-major order + const int A0 = inputA->flat_first_dim(); + const int A1 = inputA->flat_last_dim(); + const int B0 = inputB->flat_first_dim(); + const int B1 = inputB->flat_last_dim(); + + // GEMM dims in column-major order + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + NVTE_CHECK((transb == CUBLAS_OP_T ? 
B0 : B1) == k, + "GEMM inputs have incompatible dimensions (A is ", + A0, "x", A1, ", B is ", B0, "x", B1, ")"); const int ldd = m; + // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { return; } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, k, n); + const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); void *C = outputD->data.dptr; void *D = outputD->data.dptr; @@ -577,16 +597,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const size_t A0 = inputA->flat_first_dim(); - const size_t A1 = inputA->flat_last_dim(); - const size_t B0 = inputB->flat_first_dim(); - const size_t B1 = inputB->flat_last_dim(); - - const int m = transa ? A0 : A1; - const int k = transa ? A1 : A0; - const int n = transb ? B1 : B0; - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); @@ -617,12 +628,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); From 9f0707e5f39ba9f0b924e63c0c08f9e993c04ba1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 6 Apr 2025 01:12:09 +0000 Subject: [PATCH 20/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ed691aa532..6fe3539257 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -229,11 +229,10 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, - cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, cudaStream_t stream) { + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); const int A1 = inputA->flat_last_dim(); @@ -245,8 +244,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int n = transb == CUBLAS_OP_T ? B1 : B0; const int k = transa == CUBLAS_OP_T ? 
A1 : A0; NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k, - "GEMM inputs have incompatible dimensions (A is ", - A0, "x", A1, ", B is ", B0, "x", B1, ")"); + "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, + ")"); const int ldd = m; // Return immediately if GEMM is trivial @@ -597,10 +596,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, 0, 0, false, nullptr, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -628,10 +626,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, + inputCounter, stream); } void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, From f3123cf37f876588d385fb86ebc96375fbb6a4a4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 7 Apr 2025 13:19:38 -0700 Subject: [PATCH 21/21] Prune number of tests. Signed-off-by: Keith Wyss --- .../test_float8_blockwise_gemm_exact.py | 82 ++++++++++++++----- 1 file changed, 60 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 9ddb4b9989..9a1cfa2db8 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -326,25 +326,68 @@ def cublas_gemm_test_constraint_enforced( (128, 128, 128), (256, 128, 256), # non 128x128 divisible input shapes - (16, 128, 128), - (16, 64, 128), - (128, 160, 128), (320, 128, 336), (320, 64, 336), # k > 128 (256, 256, 256), (320, 256, 336), - (256, 512, 256), - (256, 1024, 256), - (1024, 1024, 1024), (1024, 4096, 1024), - (512, 128, 512), - (768, 128, 768), - (1024, 128, 1024), - (1536, 128, 1536), - (2048, 128, 2048), - (4096, 128, 4096), - (4096, 512, 3072), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + 
[ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_shape_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + (320, 256, 336), ], ) @pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @@ -364,7 +407,7 @@ def cublas_gemm_test_constraint_enforced( ], ids=["1Dx2D", "1Dx1D", "2Dx1D"], ) -def test_cublas_gemm_fp8_blockwise_shape_varying( +def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying( x_dtype, w_dtype, out_dtype, @@ -402,18 +445,16 @@ def test_cublas_gemm_fp8_blockwise_shape_varying( # k = 128 (256, 128, 256), # non 128x128 divisible input shapes - (16, 128, 128), (320, 64, 336), # k > 128 (256, 256, 256), - (4096, 128, 4096), ], ) @pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize("noise_type", ["normal"], ids=str) -@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str) @pytest.mark.parametrize("w_magnitude", [1], ids=str) @pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) @pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) @@ -468,7 +509,6 @@ def test_cublas_gemm_fp8_blockwise_bias( (16, 128, 128), (320, 64, 336), # k > 128 - (256, 256, 256), (4096, 128, 4096), ], ) @@ -540,15 +580,13 
@@ def test_cublas_gemm_fp8_blockwise_columnwise( # k = 128 (256, 128, 256), # non 128x128 divisible input shapes - (16, 128, 128), (320, 64, 336), # k > 128 (256, 256, 256), - (4096, 128, 4096), ], ) -@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("noise_type", ["normal"], ids=str) @pytest.mark.parametrize("x_magnitude", [1], ids=str)