[Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell#2157
Conversation
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
/te-ci pytorch L0
/te-ci L0
/te-ci L0 L1
Pull Request Overview
This PR enables FP8 Block Scaling (DeepSeek recipe) support on Blackwell architecture by emulating it with MXFP8, as native FP8 block scaling is only supported on Hopper. The implementation converts FP8 block scaling tensors to MXFP8 format before GEMM operations while preserving the original quantization behavior.
Key Changes:
- Removes architecture restriction limiting FP8 block scaling to Hopper only
- Adds tensor conversion from FP8 block scaling to MXFP8 format for Blackwell compatibility
- Implements custom CUDA kernels for scaling factor format conversion
Reviewed Changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| transformer_engine/pytorch/fp8.py | Updates support check to allow Blackwell architecture |
| transformer_engine/pytorch/distributed.py | Simplifies transpose logic for block scaling tensors |
| transformer_engine/pytorch/csrc/util.h | Adds function declaration for block scaling to MXFP8 conversion |
| transformer_engine/pytorch/csrc/util.cpp | Implements tensor conversion from block scaling to MXFP8 format |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | Adds Blackwell-specific tensor conversion logic in GEMM operations |
| transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu | Adds power-of-two validation for Blackwell |
| transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu | Adds power-of-two validation for Blackwell |
| transformer_engine/common/transformer_engine.cpp | Adds string representations for block scaling modes |
| transformer_engine/common/swizzle/swizzle_block_scaling.cu | New CUDA kernels for scaling factor format conversion |
| transformer_engine/common/include/transformer_engine/swizzle.h | Adds API declaration for scaling factor swizzling |
| transformer_engine/common/CMakeLists.txt | Adds new swizzle kernel source file |
| tests/pytorch/test_float8_blockwise_scaling_exact.py | Skips non-power-of-two tests on Blackwell |
| tests/pytorch/test_float8_blockwise_gemm_exact.py | Disables exact GEMM tests for emulated mode |
| tests/cpp/operator/test_cast_float8blockwise.cu | Skips non-power-of-two tests on Blackwell |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
/te-ci L1
```cpp
NVTE_CHECK(output_scale_inv_rows == DIVUP<size_t>(data_rows, 128) * 128,
           "Expected the output scaling factor matrix to have ",
           DIVUP<size_t>(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows,
           " rows instead.");
NVTE_CHECK(output_scale_inv_cols == DIVUP<size_t>(data_cols, 128) * 4,
           "Expected the output scaling factor matrix to have ",
           DIVUP<size_t>(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols,
           " columns instead.");
```
Be advised that these dimensions are lies. The MXFP8 scales are actually laid out with dims (data_rows / 128, data_cols / 128, 128*4).
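To make the layout concrete: the nominal padded 2D shape validated by the `NVTE_CHECK`s and the tiled layout described in the comment above hold the same number of elements. A small sketch (toy dimensions assumed; `divup` mirrors the `DIVUP` helper):

```python
import numpy as np

def divup(a: int, b: int) -> int:
    # Ceiling division, mirroring DIVUP<size_t> in the checks above.
    return (a + b - 1) // b

# Hypothetical data dimensions, multiples of 128 for simplicity.
data_rows, data_cols = 256, 512

# Nominal 2D shape of the output scaling factor matrix.
scale_rows = divup(data_rows, 128) * 128   # 256
scale_cols = divup(data_cols, 128) * 4     # 16

# The same buffer viewed as (data_rows/128, data_cols/128, 128*4) tiles.
flat = np.arange(scale_rows * scale_cols)
tiled = flat.reshape(data_rows // 128, data_cols // 128, 128 * 4)
print(tiled.shape)   # (2, 4, 512)
```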
/te-ci L1
Hi @janekb04, just to confirm, with this PR, we still need both rowwise and colwise (transposed) FP8 weights on Blackwell, right? Even though Blackwell supports all layouts for FP8 GEMMs and the colwise quantized weights are just transposed rowwise ones because of the 2D scaling. Am I understanding it correctly?
Yes, this PR still uses both rowwise and columnwise data. Indeed, in the case of 2D block scaling, this could be avoided because the columnwise data is the transpose of the rowwise data. In the case of 1D block scaling, however, this could increase numerical differences because the data would be quantized in the other direction and would not necessarily be the same after dequantization. In the section on future optimizations, I point out the two optimizations regarding this that I found could be applicable:
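The point about quantization direction can be seen with a toy 1D blockwise fake-quantizer (illustrative only; tiny blocks and a coarse 3-bit-like grid stand in for the real 1x128 FP8 codec):

```python
import numpy as np

def quantize_dequantize_1d(x, block):
    """Toy 1D blockwise fake-quantization along the last axis: scale each
    `block`-wide group by its absmax mapped to a 3-bit-like grid, round,
    and dequantize. Illustrative only -- not the real FP8 codec."""
    out = np.empty_like(x, dtype=float)
    for j in range(0, x.shape[-1], block):
        blk = x[..., j:j + block]
        scale = np.abs(blk).max(axis=-1, keepdims=True) / 7.0
        out[..., j:j + block] = np.round(blk / scale) * scale
    return out

x = np.array([[1.0, 3.0], [2.0, 5.0]])

rowwise = quantize_dequantize_1d(x, block=2)        # 1xN blocks along rows
colwise = quantize_dequantize_1d(x.T, block=2).T    # blocks along the other axis

# Quantizing along the other direction uses different scales, so the
# dequantized tensors differ -- colwise data is NOT just rowwise.T here.
print(np.allclose(rowwise, colwise))   # False
```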
Description
This PR adds support for the FP8 block scaling (i.e. DeepSeek) recipe on Blackwell. It exhibits some changes in behavior compared to Hopper.
Addresses this discussion from #1513.
Motivation
Currently, the FP8 block scaling recipe works only on Hopper. If you try to use it on Blackwell, the `check_fp8_block_scaling_support` function in `fp8.py` reports that it is not supported and an exception prevents further execution. If `check_fp8_block_scaling_support` is changed to instead check that the architecture is Hopper or newer, the failure occurs in `cublas_gemm` instead: cuBLASLt does not implement `cublasLtMatmul` with a `CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F` or `CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F` input type on Blackwell.
A possible workaround is to simply switch from the `Float8BlockScaling` recipe to the `MXFP8BlockScaling` recipe. However, this can result in numerical discrepancies. They occur because of differences in how the two recipes quantize tensors and because the MXFP8 recipe performs more operations in low precision than the block scaling recipe.
Implementation
This PR emulates only the GEMMs with MXFP8. This is done by converting input `NVTE_BLOCK_SCALING_1D` and `NVTE_BLOCK_SCALING_2D` tensors to `NVTE_MXFP8_1D_SCALING` just before a GEMM. The tensors' main data is not touched at all; only the format of the scaling factors is changed to be compatible with MXFP8.
The FP8 block scaling tensors are created (quantized from higher precision) using `quantize_transpose_vector_blockwise` or `quantize_transpose_square_blockwise`, the same as on Hopper. This means that, contrary to simply switching to the MXFP8 recipe, the 1x128 and 128x128 quantization block sizes are preserved (if MXFP8 were used, the scaling factors could be different, as they would correspond to 1x32 blocks).
To make the tensors valid inputs to the MXFP8 GEMM, their scaling factors are converted from the FP8 block scaling format to the MXFP8 format. I take advantage of the fact that, when entering the GEMM, the FP8 block scaling tensors are guaranteed to already use the `GEMM_READY` scaling factor format. In the case of 2D (128x128) block scaling, the scaling factors are simply "unsqueezed" by a factor of 512: a single scaling factor for a 128x128 block becomes 512 scaling factors for the 512 1x32 blocks constituting that block. In the case of 1D (1x128) block scaling, every 128 scaling factors corresponding to a 128x128 block are swizzled to match the cuBLASLt MXFP8 GEMM format and "unsqueezed" by a factor of 4 to correspond to the 512 1x32 blocks.
Limitations
The conversion from FP8 block scaling scaling factors to MXFP8 scaling factors is lossless if and only if the original FP8 block scaling scaling factors are powers of 2. This is because the MXFP8 scaling factors are not FP32 (like the FP8 block scaling scaling factors), but rather FP8E8M0. The conversion kernels assume this requirement is met and simply extract the FP32 exponent bits and treat them as FP8E8M0. If either the sign bit or any mantissa bits are set, the results will be incorrect, as the exponent bits are not masked out when performing bit shifts. Masking them out would result in numerical discrepancies in the output anyway due to discarding them.
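The power-of-two requirement can be illustrated with plain FP32 bit manipulation (a sketch of the principle only; as noted above, the actual kernels shift without masking, relying on the sign and mantissa bits being zero):

```python
import struct

def fp32_to_e8m0(x: float) -> int:
    """Extract the 8 FP32 exponent bits as a biased E8M0 exponent.
    Lossless only when x is a positive power of two (sign = 0, mantissa = 0)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF

def e8m0_to_fp32(e: int) -> float:
    """Decode a biased E8M0 exponent back to an FP32 power of two."""
    return 2.0 ** (e - 127)

# Powers of two round-trip exactly:
assert e8m0_to_fp32(fp32_to_e8m0(0.5)) == 0.5
assert e8m0_to_fp32(fp32_to_e8m0(1024.0)) == 1024.0

# Non-powers-of-two lose their mantissa: 3.0 = 1.5 * 2^1 decodes as 2.0.
print(e8m0_to_fp32(fp32_to_e8m0(3.0)))   # 2.0
```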
Despite the lossless conversion of the scaling factors, the GEMM outputs are not identical between Hopper and Blackwell, because the Blackwell MXFP8 cuBLASLt GEMM is implemented differently from the Hopper FP8 block scaling cuBLASLt GEMM. However, the overall numerical error should be smaller than when simply switching to the MXFP8 recipe.
Contrary to FP8 Block Scaling on Hopper, GEMM+GELU fusion is not currently supported as cuBLASLt doesn't support it for MXFP8 (with BF16 output, which the FP8 Block Scaling recipe uses).
Future optimizations to pursue
- Take advantage of Blackwell support for FP8 non-TN GEMMs. Contrary to Hopper, Blackwell supports non-TN GEMMs for FP8. This leads to at least the following optimization: `_post_process_fp8_blockwise_gather` in `distributed.py` transposes columnwise data after all-gather; this could be avoided.
- Use a multi kernel for scaling factor swizzling. Currently, the support for `te_general_grouped_gemm`, which is needed for MoE, is naive: the tensor conversion function `convert_block_scaling_to_mxfp8_tensor` is simply called in a loop for every tensor. This can result in many small kernel launches, which is inefficient. Instead, a multi kernel approach could be used, similar to how MXFP8 scaling factor swizzling is handled.
- Use preloading to reduce scoreboard stalls. Currently, the kernels are memory-latency bound. Before they can start loading data, they have to do a bit of address calculation, which prevents latency hiding. To alleviate this, the kernels could prefetch the entire scaling factor tensor to L2 with `cp.async.bulk.prefetch.global.L2` at the very beginning, before calculating which part of the tensor they need to read, as the tensors should be small enough to fit in L2 and will have to be read in their entirety anyway.
Performance notes
Regarding the performance of the 1D swizzling kernel: could we compute `lane_load_idx` before the scaling factors are loaded, and index `reinterpret_cast<const uint4*>(warp_src)` with `lane_load_idx` rather than `lane`? We could, and this is what I originally did. However, it turned out to be more performant to start loading the data as soon as possible, minimizing the amount of computation that has to be done before the first load. As such, I changed the kernel to the current approach in aeafe79. After preloading the tensors to L2, this might no longer be necessary.
Type of change
Changes
- `nvte_swizzle_block_scaling_to_mxfp8_scaling_factors`: a core function for swizzling the FP8 block scaling `GEMM_READY` scaling factors to the MXFP8 swizzled format.
- `swizzle_block_scaling.cu`: two new custom kernels for swizzling the 1D and 2D FP8 block scaling scaling factors.
- The quantization kernels (`quantize_transpose_square_blockwise.cu` and `quantize_transpose_vector_blockwise.cu`) now make sure that power-of-two scaling factors are used. The quantization itself does work with non-power-of-two scaling factors, but the check is done here, as it is the only place where it can be performed; in the GEMM, there is no longer a way to determine whether the tensors use power-of-two scaling factors.
- Tensor conversion in `gemm` and `te_general_grouped_gemm` in `gemm.cpp`.
- Updated support check in `fp8.py`.
- Skipped non-power-of-two tests in `test_cast_float8blockwise.cu` and `test_float8_blockwise_scaling_exact.py`.
- Disabled `test_float8_blockwise_gemm_exact.py`, because the MXFP8 GEMM has some numerical discrepancies with the FP8 blockwise scaling GEMM and is not bitwise identical, as these tests expect. Other tests, like `test_numerics.py`, will check that the GEMMs are within acceptable numerical error.
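The "unsqueezing" performed by the new swizzle kernels can be sketched in NumPy (toy shapes; this shows only the scale replication counts, not the actual cuBLASLt swizzle order):

```python
import numpy as np

rows, cols = 256, 256   # toy FP8 data tensor dims

# 2D (128x128) block scaling: one scale per 128x128 block ...
scales_2d = np.arange((rows // 128) * (cols // 128), dtype=float)
scales_2d = scales_2d.reshape(rows // 128, cols // 128)

# ... is replicated into 512 scales, one per 1x32 sub-block
# (128 rows * 4 column groups per 128x128 block).
mx_2d = np.repeat(scales_2d[:, :, None], 128 * 4, axis=2)
print(mx_2d.shape)   # (2, 2, 512)

# 1D (1x128) block scaling: one scale per 1x128 slice ...
scales_1d = np.arange(rows * (cols // 128), dtype=float).reshape(rows, cols // 128)

# ... is "unsqueezed" by 4: each scale covers the four 1x32 blocks in its slice.
mx_1d = np.repeat(scales_1d, 4, axis=1)
print(mx_1d.shape)   # (256, 8)
```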