From b51e4a33fff97f9718350a47b0886d59433a35da Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 20 Aug 2025 23:10:34 +0000 Subject: [PATCH 01/17] Update to_string(NVTEScalingMode) to include block scaling Signed-off-by: Jan Bielak --- transformer_engine/common/transformer_engine.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 55654989a7..24be6475b8 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -63,6 +63,10 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; + case NVTE_BLOCK_SCALING_1D: + return "NVTE_BLOCK_SCALING_1D"; + case NVTE_BLOCK_SCALING_2D: + return "NVTE_BLOCK_SCALING_2D"; case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"; case NVTE_INVALID_SCALING: From 3f106e346ff32f65b92d2adeb3cdcdf30070a38d Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Thu, 28 Aug 2025 22:40:32 +0000 Subject: [PATCH 02/17] Add `nvte_swizzle_block_scaling_to_mxfp8_scaling_factors` Signed-off-by: Jan Bielak --- transformer_engine/common/CMakeLists.txt | 1 + .../include/transformer_engine/swizzle.h | 20 ++ .../common/swizzle/swizzle_block_scaling.cu | 284 ++++++++++++++++++ 3 files changed, 305 insertions(+) create mode 100644 transformer_engine/common/swizzle/swizzle_block_scaling.cu diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cb9f13b899..983e768f8f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -97,6 +97,7 @@ list(APPEND transformer_engine_SOURCES util/multi_stream.cpp util/rtc.cpp swizzle/swizzle.cu + swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu 
fused_softmax/scaled_aligned_causal_masked_softmax.cu diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 079feb4a7d..624e71d1e3 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM + * + * \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv. + * \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * This function is used for emulating the FP8 block scaling recipe on Blackwell and newer as it is + * not natively supported by cublasLt on architectures other than Hopper. + + * Requirements: + * - input is an FP8 block scaling tensor + * - input has rowwise usage + * - input.scale_inv is in GEMM_READY format + * - output is an MXFP8 tensor + * - output has rowwise usage + * - output.scale_inv has appropriate shape + * */ +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu new file mode 100644 index 0000000000..34b1f28650 --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -0,0 +1,284 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace { +constexpr uint32_t WARP_SIZE = 32; +} // namespace +namespace swizzle_kernel_1d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where +// each thread stores a single row (of four bytes). +// Example: +// lane0.row = 0x00010203 +// lane1.row = 0x04050607 +// lane2.row = 0x08090a0b +// lane3.row = 0x0c0d0e0f +// Becomes: +// lane0.row = 0x0004080c +// lane1.row = 0x0105090d +// lane2.row = 0x02060a0e +// lane3.row = 0x03070b0f +uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row, + const uint32_t lane, + const uint32_t active_mask) { + using cu = const uint32_t; + + cu m_0123_4567_89ab_cdef = row; + cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4); + cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240); + cu m_5173_1537_d9fb_9dbf = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x3715); + cu m_0426_1537_8cae_9dbf = (lane & 1) ? m_5173_1537_d9fb_9dbf : m_0426_4062_8cae_c8ea; + cu m_8cae_9dbf_0426_1537 = __shfl_xor_sync(active_mask, m_0426_1537_8cae_9dbf, 2, 4); + cu m_048c_159d_8c04_9d15 = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x5410); + cu m_ae26_bf37_26ae_37bf = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x3276); + cu m_048c_159d_26ae_37bf = (lane & 2) ? m_ae26_bf37_26ae_37bf : m_048c_159d_8c04_9d15; + + return m_048c_159d_26ae_37bf; +} + +// Expands a uint32_t to a uint4 by duplicating each byte four times. 
+// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404} +uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) { + return {__byte_perm(x, 0, 0x0000), __byte_perm(x, 0, 0x1111), __byte_perm(x, 0, 0x2222), + __byte_perm(x, 0, 0x3333)}; +} + +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_x; + const uint32_t in_tile_x = out_tile_y; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling factors for this lane's initial four 1x128 tiles + const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); + uint4 sf = reinterpret_cast(warp_src)[lane_load_idx]; + + // pack the exponent bits of the scaling factors + uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); + + // transpose 4x4 matrices of scaling factors + constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches + packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK); + + // broadcast the scaling factors for 
sixteen 1x32 tiles + sf = broadcast_uint32_t_to_uint4(packed_exponents); + + // store them cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ", + alignof(uint4), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + NVTE_CHECK(data_rows % 128 == 0, + "Input scaling factors have to be available for full 128x128 tiles"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + + const uint32_t out_y_stride = tiles_x * 512; + + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); +} +} // namespace swizzle_kernel_1d +namespace swizzle_kernel_2d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = 
threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_y; + const uint32_t in_tile_x = out_tile_x; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = sizeof(float); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling factor for this warp's 128x128 tile + uint32_t sf = *reinterpret_cast(warp_src); + + // broadcast it to four scaling factors for 1x32 tiles + sf = (sf << 1) | (sf >> 7); + sf = sf | (sf >> 16); + + // broadcast it to sixteen scaling factors for 1x32 tiles + const uint4 sf4{sf, sf, sf, sf}; + + // store it cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf4; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ", + alignof(float), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4; + const uint32_t in_y_stride = input_scale_inv_cols 
* sizeof(float); + + const uint32_t out_y_stride = tiles_x * 512; + + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); +} +} // namespace swizzle_kernel_2d + +void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* output, + cudaStream_t stream) { + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "block_scaling_scaling_factor_input"); + CheckInputTensor(*output, "mxfp8_scaling_factor_output"); + + const NVTEScalingMode scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING, + "Output tensor must be an mxfp8 tensor"); + + NVTE_CHECK(input->data.dtype == transformer_engine::DType::kFloat8E4M3 || + input->data.dtype == transformer_engine::DType::kFloat8E5M2, + "Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8"); + NVTE_CHECK(output->data.dtype == input->data.dtype, + "Output data must have the same dtype as input data"); + NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32, "Input must have FP32 scaling factors"); + NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0, + "Output must have E8M0 scaling factors"); + + NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data"); + NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input"); + NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors"); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Output must have rowwise scaling factors"); + + NVTE_CHECK(input->data.shape.size() == 2, "Input data must be a matrix"); + NVTE_CHECK(output->data.shape == input->data.shape, + "Output data must have the same shape as input data"); + NVTE_CHECK(input->scale_inv.shape.size() == 2, "Input scaling 
factors must be a matrix"); + NVTE_CHECK(output->scale_inv.shape.size() == 2, "Output scaling factors must be a matrix"); + + const size_t data_rows = input->data.shape[0]; + const size_t data_cols = input->data.shape[1]; + const size_t input_scale_inv_rows = input->scale_inv.shape[0]; + const size_t input_scale_inv_cols = input->scale_inv.shape[1]; + const size_t output_scale_inv_rows = output->scale_inv.shape[0]; + const size_t output_scale_inv_cols = output->scale_inv.shape[1]; + + NVTE_CHECK(output_scale_inv_rows == DIVUP(data_rows, 128) * 128, + "Expected the output scaling factor matrix to have ", + DIVUP(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows, + " rows instead."); + NVTE_CHECK(output_scale_inv_cols == DIVUP(data_cols, 128) * 4, + "Expected the output scaling factor matrix to have ", + DIVUP(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols, + " columns instead."); + + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_cols, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_cols, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_rows, 4) * 4, + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 4) * 4, + "columns, but it has ", input_scale_inv_cols, " columns instead."); + + swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } else { // scaling_mode == NVTE_BLOCK_SCALING_2D + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_rows, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_cols, 512) * 4, + "Expected the input scaling factor matrix to have ", + DIVUP(data_cols, 512) * 4, "columns, but it has ", input_scale_inv_cols, + " columns instead."); + + 
swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } +} + +} // namespace transformer_engine + +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors); + using namespace transformer_engine; + swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input), + convertNVTETensorCheck(output), stream); +} From 072e9cbfbd487646fe3eb8c56c76a1cd59f73cca Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Thu, 28 Aug 2025 22:41:11 +0000 Subject: [PATCH 03/17] Convert FP8 block scaling tensors to MXFP8 tensors on Blackwell and newer in GEMM Signed-off-by: Jan Bielak --- .../pytorch/csrc/extensions/gemm.cpp | 17 +++++ transformer_engine/pytorch/csrc/util.cpp | 70 +++++++++++++++++++ transformer_engine/pytorch/csrc/util.h | 12 ++++ 3 files changed, 99 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f4768bb9ba..80196aa6ad 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -102,6 +102,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const bool low_precision = detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype()); + const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D; // Check tensor dimensions const auto& A_shape = A_tensor.shape(); @@ -197,6 +201,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); + // Emulate the FP8 block scaling recipe 
with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (fp8_block_scaling && transformer_engine::cuda::sm_arch() > 90) { + // Convert tensors to mxfp8 and swizzle their scaling factors + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa))); + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb))); + // Use TN GEMM to avoid having to transpose data. + transa = true; + transb = false; + } + if (comm_overlap) { // Prepare extra output tensor TensorWrapper extra_output_tensor; diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 92f2d3a500..524519d229 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -7,6 +7,7 @@ #include "util.h" #include "common.h" +#include "common/common.h" std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { @@ -170,3 +171,72 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } + +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, + bool rowwise) { + using namespace transformer_engine::pytorch; + using transformer_engine::DIVUP; + + // Check input tensor + const NVTEScalingMode scaling_mode = input.scaling_mode(); + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + + // Get tensor data + NVTEBasicTensor data; + size_t data_flat_first_dim = 1; + size_t data_flat_last_dim = 1; + if (rowwise) { + data = input.get_rowwise_data(); + for (int i = 0; i < data.shape.ndim - 1; ++i) { + data_flat_first_dim *= data.shape.data[i]; + } + data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; + } else { + data = input.get_columnwise_data(); + data_flat_first_dim = data.shape.data[0]; + for (int i = 1; i < 
data.shape.ndim; ++i) { + data_flat_last_dim *= data.shape.data[i]; + } + } + NVTEShape data_shape{}; + data_shape.data[0] = data_flat_first_dim; + data_shape.data[1] = data_flat_last_dim; + data_shape.ndim = 2; + + // Recreate input tensor with rowwise usage + transformer_engine::TensorWrapper input_cu(scaling_mode); + input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + const NVTEBasicTensor scale_inv = + rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); + input_cu.set_rowwise_scale_inv( + scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); + + // Create output tensor + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + // Output swizzled mxfp8 scaling factor dimensions + const size_t swizzled_scale_inv_first_dim = DIVUP(data_flat_first_dim, 128) * 128; + const size_t swizzled_scale_inv_last_dim = DIVUP(data_flat_last_dim, 128) * 4; + // Allocate memory for swizzled mxfp8 scaling factors + const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + at::Tensor swizzled_scale_inv = at::empty( + std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options); + // Set rowwise scaling factors on output + void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + NVTEShape swizzled_scale_inv_shape{}; + swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; + swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; + swizzled_scale_inv_shape.ndim = 2; + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + swizzled_scale_inv_shape); + + // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format + nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + // Set the input tensor to be the converted 
mxfp8 tensor and return the swizzled scaling factor + // for it to be kept alive during the GEMM + input = std::move(output_cu); + return swizzled_scale_inv; +} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 4b26860967..d9d2943547 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -27,4 +27,16 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); +/*! \brief Convert a block scaling tensor to an mxfp8 tensor. + * + * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid + * transposing it in memory. Due to differences in how block scaling and mxfp8 store data, + * this requires the calling code to treat the output tensor as having been transposed in this case. + * + * Returns the swizzled scaling factor of the converted mxfp8 tensor. + * The returned swizzled scaling factor tensor should be kept alive during the GEMM. 
+ */ +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, + bool rowwise); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ From 4cab85c8d8dcaab68b9636ead404273994c7cd40 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 20 Aug 2025 18:09:19 +0000 Subject: [PATCH 04/17] Allow Blackwell and newer in Deepseek recipe compatibility check Signed-off-by: Jan Bielak --- transformer_engine/pytorch/fp8.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 8f9dbd88d0..5f39ec5124 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -55,13 +55,12 @@ def check_mxfp8_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" - if ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ): + if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" - return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." 
+ return ( + False, + "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", + ) def check_recipe_support(recipe: Recipe) -> None: From 2dad396bca0b8a8e95f0196f08c01f0e49550f45 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Fri, 5 Sep 2025 17:52:09 -0700 Subject: [PATCH 05/17] Allow data_rows % 4 != 0 in 1d kernel Signed-off-by: Jan Bielak --- .../common/swizzle/swizzle_block_scaling.cu | 39 ++++++++++++++++--- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index 34b1f28650..65a9a92587 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -8,6 +8,7 @@ #include #include +#include #include "../common.h" #include "../util/logging.h" @@ -58,10 +59,21 @@ uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) { __byte_perm(x, 0, 0x3333)}; } +// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data +// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors. 
+struct no_oob_tag_t {}; +constexpr no_oob_tag_t NO_OOB_TAG; + +template void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel( const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, - const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride, + OOBT first_oob) { + // resolve kernel variant + constexpr bool no_oob = std::is_same_v; + static_assert(no_oob || std::is_same_v); + // load thread indices const uint32_t lane = threadIdx.x; __builtin_assume(lane < WARP_SIZE); @@ -87,7 +99,16 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) // load scaling factors for this lane's initial four 1x128 tiles const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); - uint4 sf = reinterpret_cast(warp_src)[lane_load_idx]; + uint4 sf; + if constexpr (no_oob) { + sf = reinterpret_cast(warp_src)[lane_load_idx]; + } else { + if ((out_tile_y < tiles_y - 1) || lane_load_idx < first_oob) { + sf = reinterpret_cast(warp_src)[lane_load_idx]; + } else { + sf = uint4{0, 0, 0, 0}; + } + } // pack the exponent bits of the scaling factors uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); @@ -111,8 +132,7 @@ void launch_kernel(const void* const in, void* const out, uint32_t data_rows, ui alignof(uint4), " bytes"); NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); - NVTE_CHECK(data_rows % 128 == 0, - "Input scaling factors have to be available for full 128x128 tiles"); + NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors"); const uint32_t tiles_x = DIVUP(data_cols, 128u); const uint32_t tiles_y = DIVUP(data_rows, 128u); @@ -124,8 +144,15 @@ void launch_kernel(const void* const in, void* const out, 
uint32_t data_rows, ui const uint32_t out_y_stride = tiles_x * 512; - swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( - in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); + const uint32_t first_oob = (input_scale_inv_cols % 128) / 4; + + if (first_oob == 0) { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, NO_OOB_TAG); + } else { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, first_oob); + } } } // namespace swizzle_kernel_1d namespace swizzle_kernel_2d { From aeafe79ad9b25f918b33cd8a704e7038102d083b Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Mon, 8 Sep 2025 23:57:43 +0000 Subject: [PATCH 06/17] Load scaling factors in unswizzled order in 1d kernel Signed-off-by: Jan Bielak --- .../common/swizzle/swizzle_block_scaling.cu | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index 65a9a92587..c0a378e9d3 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -98,13 +98,12 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; // load scaling factors for this lane's initial four 1x128 tiles - const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); uint4 sf; if constexpr (no_oob) { - sf = reinterpret_cast(warp_src)[lane_load_idx]; + sf = reinterpret_cast(warp_src)[lane]; } else { - if ((out_tile_y < tiles_y - 1) || lane_load_idx < first_oob) { - sf = reinterpret_cast(warp_src)[lane_load_idx]; + if ((out_tile_y < tiles_y - 1) || lane < first_oob) { + sf = reinterpret_cast(warp_src)[lane]; } else { sf = uint4{0, 0, 0, 0}; } @@ -113,8 +112,12 @@ void __global__ 
__launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) // pack the exponent bits of the scaling factors uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); - // transpose 4x4 matrices of scaling factors + // partially swizzle the scaling factors constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches + const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); + packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx); + + // transpose 4x4 matrices of scaling factors packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK); // broadcast the scaling factors for sixteen 1x32 tiles From e1334b657801dc84313f70e83a1e7b1480f67dbf Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 10 Sep 2025 20:26:49 +0000 Subject: [PATCH 07/17] Enforce use of power of two scaling Signed-off-by: Jan Bielak --- .../common/transpose/quantize_transpose_square_blockwise.cu | 6 ++++++ .../common/transpose/quantize_transpose_vector_blockwise.cu | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index c3f085b877..5399287194 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,6 +14,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -485,6 +486,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor NVTE_API_CALL(quantize_transpose_square_blockwise); checkCuDriverContext(stream); + if (transformer_engine::cuda::sm_arch() > 90) { + NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power 
of two scaling factors."); + } + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 4c82b8c81b..71149c4f51 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,6 +17,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include "common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -529,6 +530,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); + if (transformer_engine::cuda::sm_arch() > 90) { + NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; size_t num_rows = 1; From aede6430c8931e3860b4bb6719cdb64761840b99 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 10 Sep 2025 20:40:28 +0000 Subject: [PATCH 08/17] Skip the FP8 block scaling exact GEMM test on Blackwell Signed-off-by: Jan Bielak --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index ec23cfe8c5..6d9258c6fd 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -19,7 +20,8 @@ def fp8_blockwise_gemm_supported() -> bool: supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() - return supported + emulated = get_device_compute_capability() > (9, 0) + return supported and not emulated def cublas_gemm_fp8_blockwise_case( From 4d7faacefde36c399d39f405b312dc6df67418cf Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 10 Sep 2025 22:25:19 +0000 Subject: [PATCH 09/17] Skip further tests with pow_2_scales=False Signed-off-by: Jan Bielak --- tests/pytorch/test_float8_blockwise_scaling_exact.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 858ce73b6b..1dcc854946 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from 
transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8BlockScaling +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -32,6 +33,7 @@ if tensor_dump_dir_env is not None: TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() +recipe_emulated = get_device_compute_capability() > (9, 0) class GetRecipes: @@ -218,6 +220,10 @@ def check_quantization_block_tiling_versus_reference( pow_2_scales: bool, tile_size: Tuple[int, int], ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip("On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors.") + te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): block_scaling_dim = 1 @@ -409,6 +415,10 @@ def test_quantization_block_tiling_extrema_versus_reference( tile_size: Tuple[int, int], extrema_high: bool, ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip("On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors.") + # This test runs a single tile through a quantizer as a way to test # branch coverage of scale computation. 
te_dtype = TE_DType[quant_dtype] From 2a0af2af08472bdcd160a75483e95becf943ef41 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 22:25:50 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_float8_blockwise_scaling_exact.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 1dcc854946..ff78eb26ba 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -221,8 +221,10 @@ def check_quantization_block_tiling_versus_reference( tile_size: Tuple[int, int], ) -> None: if recipe_emulated and not pow_2_scales: - pytest.skip("On Blackwell and newer, the FP8 block scaling recipe is emulated " - "with MXFP8, which requires using power of two scaling factors.") + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): @@ -416,8 +418,10 @@ def test_quantization_block_tiling_extrema_versus_reference( extrema_high: bool, ) -> None: if recipe_emulated and not pow_2_scales: - pytest.skip("On Blackwell and newer, the FP8 block scaling recipe is emulated " - "with MXFP8, which requires using power of two scaling factors.") + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) # This test runs a single tile through a quantizer as a way to test # branch coverage of scale computation. 
From e288f56b5f8817185c964462e8b5324f0787252e Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 16 Sep 2025 20:38:20 +0000 Subject: [PATCH 11/17] Initial implementation of tensor conversion for grouped gemm Signed-off-by: Jan Bielak --- .../pytorch/csrc/extensions/gemm.cpp | 82 ++++++++++++++----- 1 file changed, 60 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 80196aa6ad..d504449964 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -355,15 +355,6 @@ std::optional> te_general_grouped_gemm( std::vector bias, DType bias_type, bool single_output, std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { - std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, - te_pre_gelu_out_vector, te_workspace_vector; - std::vector te_A_wrappers, te_B_wrappers, wrappers; - std::vector D_vectors; - - auto none = py::none(); - - std::vector single_output_begins; - std::vector single_output_ends; if (single_output && D == std::nullopt) { NVTE_ERROR("not implemented, D should be allocated for single output case."); } @@ -373,6 +364,10 @@ std::optional> te_general_grouped_gemm( output_data_ptr = (*D)[0].data_ptr(); } + const auto none = py::none(); + std::vector te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers, + te_pre_gelu_out_wrappers; + std::vector D_vectors; for (size_t i = 0; i < A.size(); i++) { auto te_A = makeTransformerEngineTensor(A[i], none); auto te_B = makeTransformerEngineTensor(B[i], none); @@ -438,29 +433,72 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); - te_A_vector.emplace_back(te_A.data()); - te_B_vector.emplace_back(te_B.data()); - te_D_vector.emplace_back(te_D.data()); 
- te_bias_vector.emplace_back(te_bias.data()); - te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - te_A_wrappers.emplace_back(std::move(te_A)); te_B_wrappers.emplace_back(std::move(te_B)); - wrappers.emplace_back(std::move(te_D)); - wrappers.emplace_back(std::move(te_bias)); - wrappers.emplace_back(std::move(te_pre_gelu_out)); + te_D_wrappers.emplace_back(std::move(te_D)); + te_bias_wrappers.emplace_back(std::move(te_bias)); + te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } + // Keep the swizzled scaling factor tensors alive during the GEMM. + std::vector> swizzled_scale_inverses_list; + // Optionally swizzle the scaling factors - // Keep the swizzled scaling factor tensors alive during the GEMMs. - auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); - auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa)); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb)); + + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (transformer_engine::cuda::sm_arch() > 90) { + // Check if is using FP8 block scaling + bool exists_tensor_using_fp8_block_scaling = false; + bool exists_tensor_not_using_fp8_block_scaling = false; + for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) { + for (const TensorWrapper& tensor : *tensor_wrappers) { + const NVTEScalingMode scaling_mode = tensor.scaling_mode(); + if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) + exists_tensor_using_fp8_block_scaling = true; + else + exists_tensor_not_using_fp8_block_scaling = true; + } + } + if (exists_tensor_using_fp8_block_scaling) { + NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling, + "Either all tensors or no 
tensor must be FP8 block scaling tensors"); + // Convert tensors to mxfp8 and swizzle their scaling factors + for (TensorWrapper& A_tensor : te_A_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)); + } + for (TensorWrapper& B_tensor : te_B_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)); + } + // Use TN GEMM to avoid having to transpose data. + transa = true; + transb = false; + } + } + + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector; + for (size_t i = 0; i < te_A_wrappers.size(); i++) { + te_A_vector.emplace_back(te_A_wrappers[i].data()); + te_B_vector.emplace_back(te_B_wrappers[i].data()); + te_D_vector.emplace_back(te_D_wrappers[i].data()); + te_bias_vector.emplace_back(te_bias_wrappers[i].data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data()); + } + std::vector te_workspace_vector; + std::vector te_workspace_wrappers; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); - wrappers.emplace_back(std::move(wsp)); + te_workspace_wrappers.emplace_back(std::move(wsp)); } // For now, we only have multi-stream cublas backend. 
From 2879c6eb60e57b290421b4b03e9469ff216ab663 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Thu, 18 Sep 2025 00:14:20 +0000 Subject: [PATCH 12/17] Skip non power of two scaling cpp unit tests Signed-off-by: Jan Bielak --- tests/cpp/operator/test_cast_float8blockwise.cu | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index e5faa688ce..b8f2a63e93 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -501,6 +501,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 2u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() > hopperComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. @@ -552,6 +558,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 1u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() > hopperComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. 
From b9e6d16c73a065e8a94dfeece33ae8860fc0657a Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Thu, 18 Sep 2025 01:14:19 +0000 Subject: [PATCH 13/17] Fix handling of all gather Signed-off-by: Jan Bielak --- transformer_engine/pytorch/distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 217cb98c74..d3aa066db1 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1013,10 +1013,10 @@ def _post_process_fp8_blockwise_gather( return out needs_columnwise_data_transpose = ( - quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() + quantizer is not None and quantizer.columnwise_usage ) need_rowwise_scale_transpose = ( - quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported() + quantizer is not None and quantizer.rowwise_usage ) # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 From d0843e9ddb21133cccaa9db733e65dcaadd40106 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Sep 2025 01:14:49 +0000 Subject: [PATCH 14/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/distributed.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index d3aa066db1..fdc83e3fd7 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1012,12 +1012,8 @@ def _post_process_fp8_blockwise_gather( if out._is_gemm_ready_format(): return out - needs_columnwise_data_transpose = ( - quantizer is not None and quantizer.columnwise_usage - ) - need_rowwise_scale_transpose = ( - quantizer is not None and 
quantizer.rowwise_usage - ) + needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage + need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 # columnwise compact format means doing 128x1 quantization of it From 293e8321fe023e27a0cdebedcb7f2e37b7221544 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 23 Sep 2025 14:48:53 -0700 Subject: [PATCH 15/17] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jan Bielak --- transformer_engine/common/swizzle/swizzle_block_scaling.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index c0a378e9d3..b245317c22 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -285,7 +285,7 @@ void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* " rows, but it has ", input_scale_inv_rows, " rows instead."); NVTE_CHECK(input_scale_inv_cols == DIVUP(data_rows, 4) * 4, "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 4) * 4, - "columns, but it has ", input_scale_inv_cols, " columns instead."); + " columns, but it has ", input_scale_inv_cols, " columns instead."); swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, data_cols, stream); @@ -295,7 +295,7 @@ void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* " rows, but it has ", input_scale_inv_rows, " rows instead."); NVTE_CHECK(input_scale_inv_cols == DIVUP(data_cols, 512) * 4, "Expected the input scaling factor matrix to have ", - DIVUP(data_cols, 512) * 4, "columns, but it has ", input_scale_inv_cols, + DIVUP(data_cols, 512) * 4, " columns, but 
it has ", input_scale_inv_cols, " columns instead."); swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, From b9271622b73d49790f7f68d388068000dc4badff Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 1 Oct 2025 16:45:36 -0700 Subject: [PATCH 16/17] Use compute capability 10.0 for logic with Blackwell Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/cpp/operator/test_cast_float8blockwise.cu | 4 ++-- tests/pytorch/test_float8_blockwise_gemm_exact.py | 2 +- tests/pytorch/test_float8_blockwise_scaling_exact.py | 2 +- .../common/transpose/quantize_transpose_square_blockwise.cu | 2 +- .../common/transpose/quantize_transpose_vector_blockwise.cu | 2 +- transformer_engine/pytorch/csrc/extensions/gemm.cpp | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index b8f2a63e93..fe4ae2d264 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -503,7 +503,7 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, // which requires using power of two scaling factors. Skip unsupported tests. - if (getDeviceComputeCapability() > hopperComputeCapability && !force_pow_2) { + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { GTEST_SKIP(); } @@ -560,7 +560,7 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, // which requires using power of two scaling factors. Skip unsupported tests. 
- if (getDeviceComputeCapability() > hopperComputeCapability && !force_pow_2) { + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { GTEST_SKIP(); } diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 6d9258c6fd..bdc73519be 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -20,7 +20,7 @@ def fp8_blockwise_gemm_supported() -> bool: supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() - emulated = get_device_compute_capability() > (9, 0) + emulated = get_device_compute_capability() >= (10, 0) return supported and not emulated diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index ff78eb26ba..51e0d1ec9b 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -33,7 +33,7 @@ if tensor_dump_dir_env is not None: TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() -recipe_emulated = get_device_compute_capability() > (9, 0) +recipe_emulated = get_device_compute_capability() >= (10, 0) class GetRecipes: diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 5399287194..661cf339ae 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -486,7 +486,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor NVTE_API_CALL(quantize_transpose_square_blockwise); checkCuDriverContext(stream); - if (transformer_engine::cuda::sm_arch() > 90) { + if (transformer_engine::cuda::sm_arch() >= 100) { 
NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", "with MXFP8, which requires using power of two scaling factors."); } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 024bfab6d0..fcf7a151c3 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -530,7 +530,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); - if (transformer_engine::cuda::sm_arch() > 90) { + if (transformer_engine::cuda::sm_arch() >= 100) { NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", "with MXFP8, which requires using power of two scaling factors."); } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 4204b1457d..3c26dd5f1c 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -228,7 +228,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt - if (fp8_block_scaling && transformer_engine::cuda::sm_arch() > 90) { + if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { // Convert tensors to mxfp8 and swizzle their scaling factors swizzled_scale_inverses_list.emplace_back( std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa))); From 87f5db5cf75a7bcaf2ab03fe2862d14bf5f6bcf0 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 2 Oct 2025 19:16:45 -0700 Subject: [PATCH 17/17] Apply suggestions from code review 
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../common/swizzle/swizzle_block_scaling.cu | 15 +++++++++++---- .../pytorch/csrc/extensions/gemm.cpp | 2 +- transformer_engine/pytorch/csrc/util.h | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu index b245317c22..4be85474af 100644 --- a/transformer_engine/common/swizzle/swizzle_block_scaling.cu +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -39,6 +39,9 @@ uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row const uint32_t active_mask) { using cu = const uint32_t; + // Threads operate in groups of 4, and each thread stores 4 bytes at a time. + // The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes + // until we have transposed the 4x4 matrix. cu m_0123_4567_89ab_cdef = row; cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4); cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240); @@ -142,10 +145,12 @@ void launch_kernel(const void* const in, void* const out, uint32_t data_rows, ui const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + // Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales + // and a 128x4 tile in the output scales. The input scales are in transposed order. 
const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); - - const uint32_t out_y_stride = tiles_x * 512; + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); const uint32_t first_oob = (input_scale_inv_cols % 128) / 4; @@ -217,10 +222,12 @@ void launch_kernel(const void* const in, void* const out, uint32_t data_rows, ui const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + // Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales + // and a 128x4 tile in the output scales. const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); - - const uint32_t out_y_stride = tiles_x * 512; + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<>>( in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index de9b1f3f9f..15404ad9a6 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -492,7 +492,7 @@ std::optional> te_general_grouped_gemm( // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt - if (transformer_engine::cuda::sm_arch() > 90) { + if (transformer_engine::cuda::sm_arch() >= 100) { // Check if is using FP8 block scaling bool exists_tensor_using_fp8_block_scaling = false; bool exists_tensor_not_using_fp8_block_scaling = false; diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 
d9d2943547..57eee86d2a 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -27,7 +27,7 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); -/*! \brief Convert a block scaling tensor to an mxfp8 tensor. +/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. * * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid * transposing it in memory. Due to differences in how block scaling and mxfp8 store data,