diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index e5faa688ce..fe4ae2d264 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -501,6 +501,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 2u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. @@ -552,6 +558,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 1u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index ec23cfe8c5..bdc73519be 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -19,7 +20,8 @@ def fp8_blockwise_gemm_supported() -> bool: supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() - return supported + emulated = get_device_compute_capability() >= (10, 0) + return supported and not emulated def cublas_gemm_fp8_blockwise_case( diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 858ce73b6b..51e0d1ec9b 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8BlockScaling +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -32,6 +33,7 @@ if tensor_dump_dir_env is not None: TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() +recipe_emulated = get_device_compute_capability() >= (10, 0) class GetRecipes: @@ -218,6 +220,12 @@ def check_quantization_block_tiling_versus_reference( pow_2_scales: bool, tile_size: Tuple[int, int], ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): block_scaling_dim = 1 @@ -409,6 +417,12 @@ def test_quantization_block_tiling_extrema_versus_reference( tile_size: Tuple[int, int], extrema_high: bool, ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + # This test runs a single tile through a quantizer as a way to test # branch coverage of scale computation. te_dtype = TE_DType[quant_dtype] diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e0fe3c04a6..92b57897de 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -127,6 +127,7 @@ list(APPEND transformer_engine_SOURCES util/multi_stream.cpp util/rtc.cpp swizzle/swizzle.cu + swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 079feb4a7d..624e71d1e3 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM + * + * \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv. + * \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * This function is used for emulating the FP8 block scaling recipe on Blackwell and newer as it + * not natively supported by cublasLt on architectures other than Hopper. + + * Requirements: + * - input is an FP8 block scaling tensor + * - input has rowwise usage + * - input.scale_inv is in GEMM_READY format + * - output is an MXFP8 tensor + * - output has rowwise usage + * - output.scale_inv has appropriate shape + * */ +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu new file mode 100644 index 0000000000..4be85474af --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -0,0 +1,321 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace { +constexpr uint32_t WARP_SIZE = 32; +} // namespace +namespace swizzle_kernel_1d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where +// each thread stores a single row (of four bytes). +// Example: +// lane0.row = 0x00010203 +// lane1.row = 0x04050607 +// lane2.row = 0x08090a0b +// lane3.row = 0x0c0d0e0f +// Becomes: +// lane0.row = 0x0004080c +// lane1.row = 0x0105090d +// lane2.row = 0x02060a0e +// lane3.row = 0x03070b0f +uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row, + const uint32_t lane, + const uint32_t active_mask) { + using cu = const uint32_t; + + // Threads operate in groups of 4, and each thread stores 4 bytes at a time. + // The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes + // until we have transposed the 4x4 matrix. + cu m_0123_4567_89ab_cdef = row; + cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4); + cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240); + cu m_5173_1537_d9fb_9dbf = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x3715); + cu m_0426_1537_8cae_9dbf = (lane & 1) ? m_5173_1537_d9fb_9dbf : m_0426_4062_8cae_c8ea; + cu m_8cae_9dbf_0426_1537 = __shfl_xor_sync(active_mask, m_0426_1537_8cae_9dbf, 2, 4); + cu m_048c_159d_8c04_9d15 = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x5410); + cu m_ae26_bf37_26ae_37bf = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x3276); + cu m_048c_159d_26ae_37bf = (lane & 2) ? m_ae26_bf37_26ae_37bf : m_048c_159d_8c04_9d15; + + return m_048c_159d_26ae_37bf; +} + +// Expands a uint32_t to a uint4 by duplicating each byte four times. +// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404} +uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) { + return {__byte_perm(x, 0, 0x0000), __byte_perm(x, 0, 0x1111), __byte_perm(x, 0, 0x2222), + __byte_perm(x, 0, 0x3333)}; +} + +// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data +// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors. +struct no_oob_tag_t {}; +constexpr no_oob_tag_t NO_OOB_TAG; + +template +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride, + OOBT first_oob) { + // resolve kernel variant + constexpr bool no_oob = std::is_same_v; + static_assert(no_oob || std::is_same_v); + + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_x; + const uint32_t in_tile_x = out_tile_y; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling factors for this lane's initial four 1x128 tiles + uint4 sf; + if constexpr (no_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + if ((out_tile_y < tiles_y - 1) || lane < first_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + sf = uint4{0, 0, 0, 0}; + } + } + + // pack the exponent bits of the scaling factors + uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); + + // partially swizzle the scaling factors + constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches + const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); + packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx); + + // transpose 4x4 matrices of scaling factors + packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK); + + // broadcast the scaling factors for sixteen 1x32 tiles + sf = broadcast_uint32_t_to_uint4(packed_exponents); + + // store them cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ", + alignof(uint4), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales + // and a 128x4 tile in the output scales. The input scales are in transposed order. + const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + const uint32_t first_oob = (input_scale_inv_cols % 128) / 4; + + if (first_oob == 0) { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, NO_OOB_TAG); + } else { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, first_oob); + } +} +} // namespace swizzle_kernel_1d +namespace swizzle_kernel_2d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_y; + const uint32_t in_tile_x = out_tile_x; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = sizeof(float); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling factor for this warp's 128x128 tile + uint32_t sf = *reinterpret_cast(warp_src); + + // broadcast it to four scaling factors for 1x32 tiles + sf = (sf << 1) | (sf >> 7); + sf = sf | (sf >> 16); + + // broadcast it to sixteen scaling factors for 1x32 tiles + const uint4 sf4{sf, sf, sf, sf}; + + // store it cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf4; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ", + alignof(float), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales + // and a 128x4 tile in the output scales. + const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); +} +} // namespace swizzle_kernel_2d + +void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* output, + cudaStream_t stream) { + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "block_scaling_scaling_factor_input"); + CheckInputTensor(*output, "mxfp8_scaling_factor_output"); + + const NVTEScalingMode scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING, + "Output tensor must be an mxfp8 tensor"); + + NVTE_CHECK(input->data.dtype == transformer_engine::DType::kFloat8E4M3 || + input->data.dtype == transformer_engine::DType::kFloat8E5M2, + "Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8"); + NVTE_CHECK(output->data.dtype == input->data.dtype, + "Output data must have the same dtype as input data"); + NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32, "Input must have FP32 scaling factors"); + NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0, + "Output must have E8M0 scaling factors"); + + NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data"); + NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input"); + NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors"); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Output must have rowwise scaling factors"); + + NVTE_CHECK(input->data.shape.size() == 2, "Input data must be a matrix"); + NVTE_CHECK(output->data.shape == input->data.shape, + "Output data must have the same shape as input data"); + NVTE_CHECK(input->scale_inv.shape.size() == 2, "Input scaling factors must be a matrix"); + NVTE_CHECK(output->scale_inv.shape.size() == 2, "Output scaling factors must be a matrix"); + + const size_t data_rows = input->data.shape[0]; + const size_t data_cols = input->data.shape[1]; + const size_t input_scale_inv_rows = input->scale_inv.shape[0]; + const size_t input_scale_inv_cols = input->scale_inv.shape[1]; + const size_t output_scale_inv_rows = output->scale_inv.shape[0]; + const size_t output_scale_inv_cols = output->scale_inv.shape[1]; + + NVTE_CHECK(output_scale_inv_rows == DIVUP(data_rows, 128) * 128, + "Expected the output scaling factor matrix to have ", + DIVUP(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows, + " rows instead."); + NVTE_CHECK(output_scale_inv_cols == DIVUP(data_cols, 128) * 4, + "Expected the output scaling factor matrix to have ", + DIVUP(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols, + " columns instead."); + + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_cols, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_cols, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_rows, 4) * 4, + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 4) * 4, + " columns, but it has ", input_scale_inv_cols, " columns instead."); + + swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } else { // scaling_mode == NVTE_BLOCK_SCALING_2D + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_rows, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_cols, 512) * 4, + "Expected the input scaling factor matrix to have ", + DIVUP(data_cols, 512) * 4, " columns, but it has ", input_scale_inv_cols, + " columns instead."); + + swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } +} + +} // namespace transformer_engine + +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors); + using namespace transformer_engine; + swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input), + convertNVTETensorCheck(output), stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index f49fe239aa..35e8b683ad 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -64,6 +64,10 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; + case NVTE_BLOCK_SCALING_1D: + return "NVTE_BLOCK_SCALING_1D"; + case NVTE_BLOCK_SCALING_2D: + return "NVTE_BLOCK_SCALING_2D"; case NVTE_NVFP4_1D_SCALING: return "NVTE_NVFP4_1D_SCALING"; case NVTE_INVALID_SCALING: diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index c3f085b877..661cf339ae 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,6 +14,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -485,6 +486,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor NVTE_API_CALL(quantize_transpose_square_blockwise); checkCuDriverContext(stream); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index d38bf79963..fcf7a151c3 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,6 +17,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include "common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -529,6 +530,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; size_t num_rows = 1; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 1364597519..15404ad9a6 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -104,6 +104,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const bool low_precision = detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype()); + const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D; // Check tensor dimensions const auto& A_shape = A_tensor.shape(); @@ -235,6 +239,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { + // Convert tensors to mxfp8 and swizzle their scaling factors + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa))); + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb))); + // Use TN GEMM to avoid having to transpose data. + transa = true; + transb = false; + } + if (comm_overlap) { // Prepare extra output tensor TensorWrapper extra_output_tensor; @@ -379,15 +396,6 @@ std::optional> te_general_grouped_gemm( std::vector bias, DType bias_type, bool single_output, std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { - std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, - te_pre_gelu_out_vector, te_workspace_vector; - std::vector te_A_wrappers, te_B_wrappers, wrappers; - std::vector D_vectors; - - auto none = py::none(); - - std::vector single_output_begins; - std::vector single_output_ends; if (single_output && D == std::nullopt) { NVTE_ERROR("not implemented, D should be allocated for single output case."); } @@ -397,6 +405,10 @@ std::optional> te_general_grouped_gemm( output_data_ptr = (*D)[0].data_ptr(); } + const auto none = py::none(); + std::vector te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers, + te_pre_gelu_out_wrappers; + std::vector D_vectors; for (size_t i = 0; i < A.size(); i++) { auto te_A = makeTransformerEngineTensor(A[i], none); auto te_B = makeTransformerEngineTensor(B[i], none); @@ -462,29 +474,72 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); - te_A_vector.emplace_back(te_A.data()); - te_B_vector.emplace_back(te_B.data()); - te_D_vector.emplace_back(te_D.data()); - te_bias_vector.emplace_back(te_bias.data()); - te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - te_A_wrappers.emplace_back(std::move(te_A)); te_B_wrappers.emplace_back(std::move(te_B)); - wrappers.emplace_back(std::move(te_D)); - wrappers.emplace_back(std::move(te_bias)); - wrappers.emplace_back(std::move(te_pre_gelu_out)); + te_D_wrappers.emplace_back(std::move(te_D)); + te_bias_wrappers.emplace_back(std::move(te_bias)); + te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } + // Keep the swizzled scaling factor tensors alive during the GEMM. + std::vector> swizzled_scale_inverses_list; + // Optionally swizzle the scaling factors - // Keep the swizzled scaling factor tensors alive during the GEMMs. - auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); - auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa)); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb)); + + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (transformer_engine::cuda::sm_arch() >= 100) { + // Check if is using FP8 block scaling + bool exists_tensor_using_fp8_block_scaling = false; + bool exists_tensor_not_using_fp8_block_scaling = false; + for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) { + for (const TensorWrapper& tensor : *tensor_wrappers) { + const NVTEScalingMode scaling_mode = tensor.scaling_mode(); + if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) + exists_tensor_using_fp8_block_scaling = true; + else + exists_tensor_not_using_fp8_block_scaling = true; + } + } + if (exists_tensor_using_fp8_block_scaling) { + NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling, + "Either all tensors or no tensor must be FP8 block scaling tensors"); + // Convert tensors to mxfp8 and swizzle their scaling factors + for (TensorWrapper& A_tensor : te_A_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)); + } + for (TensorWrapper& B_tensor : te_B_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)); + } + // Use TN GEMM to avoid having to transpose data. + transa = true; + transb = false; + } + } + + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector; + for (size_t i = 0; i < te_A_wrappers.size(); i++) { + te_A_vector.emplace_back(te_A_wrappers[i].data()); + te_B_vector.emplace_back(te_B_wrappers[i].data()); + te_D_vector.emplace_back(te_D_wrappers[i].data()); + te_bias_vector.emplace_back(te_bias_wrappers[i].data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data()); + } + std::vector te_workspace_vector; + std::vector te_workspace_wrappers; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); - wrappers.emplace_back(std::move(wsp)); + te_workspace_wrappers.emplace_back(std::move(wsp)); } // For now, we only have multi-stream cublas backend. diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 3bb6be715d..ffba5b2763 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -7,6 +7,7 @@ #include "util.h" #include "common.h" +#include "common/common.h" std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { @@ -177,3 +178,72 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } + +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, + bool rowwise) { + using namespace transformer_engine::pytorch; + using transformer_engine::DIVUP; + + // Check input tensor + const NVTEScalingMode scaling_mode = input.scaling_mode(); + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + + // Get tensor data + NVTEBasicTensor data; + size_t data_flat_first_dim = 1; + size_t data_flat_last_dim = 1; + if (rowwise) { + data = input.get_rowwise_data(); + for (int i = 0; i < data.shape.ndim - 1; ++i) { + data_flat_first_dim *= data.shape.data[i]; + } + data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; + } else { + data = input.get_columnwise_data(); + data_flat_first_dim = data.shape.data[0]; + for (int i = 1; i < data.shape.ndim; ++i) { + data_flat_last_dim *= data.shape.data[i]; + } + } + NVTEShape data_shape{}; + data_shape.data[0] = data_flat_first_dim; + data_shape.data[1] = data_flat_last_dim; + data_shape.ndim = 2; + + // Recreate input tensor with rowwise usage + transformer_engine::TensorWrapper input_cu(scaling_mode); + input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + const NVTEBasicTensor scale_inv = + rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); + input_cu.set_rowwise_scale_inv( + scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); + + // Create output tensor + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + // Output swizzled mxfp8 scaling factor dimensions + const size_t swizzled_scale_inv_first_dim = DIVUP(data_flat_first_dim, 128) * 128; + const size_t swizzled_scale_inv_last_dim = DIVUP(data_flat_last_dim, 128) * 4; + // Allocate memory for swizzled mxfp8 scaling factors + const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + at::Tensor swizzled_scale_inv = at::empty( + std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options); + // Set rowwise scaling factors on output + void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + NVTEShape swizzled_scale_inv_shape{}; + swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; + swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; + swizzled_scale_inv_shape.ndim = 2; + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + swizzled_scale_inv_shape); + + // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format + nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor + // for it to be kept alive during the GEMM + input = std::move(output_cu); + return swizzled_scale_inv; +} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 4b26860967..57eee86d2a 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -27,4 +27,16 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); +/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. + * + * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid + * transposing it in memory. Due to differences in how block scaling and mxfp8 store data, + * this requires the calling code to treat the output tensor as having been tranposed in this case. + * + * Returns the swizzled scaling factor of the converted mxfp8 tensor. + * The returned swizzled scaling factor tensor should be kept alive during the GEMM. + */ +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, + bool rowwise); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c001e8e79a..51fbb50c4c 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1015,12 +1015,8 @@ def _post_process_fp8_blockwise_gather( if out._is_gemm_ready_format(): return out - needs_columnwise_data_transpose = ( - quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() - ) - need_rowwise_scale_transpose = ( - quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported() - ) + needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage + need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 # columnwise compact format means doing 128x1 quantization of it diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index a62e10bc57..bfe241f81b 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -64,13 +64,12 @@ def check_nvfp4_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" - if ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ): + if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" - return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." + return ( + False, + "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", + ) def check_recipe_support(recipe: Recipe) -> None: