Skip to content

[Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell#2157

Merged
timmoon10 merged 25 commits intoNVIDIA:mainfrom
janekb04:deepseek-blackwell
Oct 3, 2025
Merged

[Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell#2157
timmoon10 merged 25 commits intoNVIDIA:mainfrom
janekb04:deepseek-blackwell

Conversation

@janekb04
Copy link
Collaborator

@janekb04 janekb04 commented Sep 5, 2025

Description

This PR adds support for the FP8 block scaling (ie. DeepSeek) recipe on Blackwell. It exhibits some changes in behavior compared to Hopper.

Addresses this discussion from #1513.

Motivation

Currently, the FP8 block scaling recipe works only on Hopper. If you try to use it on Blackwell, the check_fp8_block_scaling_support function in fp8.py will report that is not supported and an exception will prevent further execution. If check_fp8_block_scaling_support is changed to instead check that the architecture is Hopper or newer, the failure occurs in cublas_gemm instead. Namely, cuBLASLt does not implement cublasLtMatmul with a CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F or CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F input type on Blackwell.

A possible workaround is to simply switch from using the Float8BlockScaling recipe to the MXFP8BlockScaling recipe. However, this can result in numerical discrepancies. They occur because of differences in how the two recipes quantize tensors and because the MXFP8 recipe performs more operations in low precision than the block scaling recipe.

Implementation

This PR emulates only the GEMMs with MXFP8. This is done by converting input NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D tensors to NVTE_MXFP8_1D_SCALING just before a GEMM. The tensors' main data is not touched at all, only the format of the scaling factors is changed to be compatible with MXFP8.

The FP8 block scaling tensors are created (quantized from higher precision) using quantize_transpose_vector_blockwise or quantize_transpose_square_blockwise - the same as on Hopper. This means that contrary to simply switching to the MXFP8 recipe, the 1x128 and 128x128 quantization block size is preserved (if MXFP8 was used, the scaling factors could be different, as they would correspond to 1x32 blocks).

To make the tensors valid inputs to the MXFP8 GEMM, their scaling factors are converted from the FP8 block scaling format to the MXFP8 format. I take advantage of the fact that when entering the GEMM, the FP8 Block Scaling tensors are guaranteed to already use the GEMM_READY scaling factor format. In case of 2D (128x128) block scaling, the scaling factors are simply "unsqueezed" by a factor of 512 - a single scaling factor for a 128x128 block becomes 512 scaling factors for 512 1x32 blocks constituting the 128x128 block. In case of 1D (1x128) block scaling, every 128 scaling factors corresponding to a 128x128 block are swizzled to match the cuBLASLt MXFP8 GEMM format and "unsqueezed" by a factor of 4 to correspond to the 512 1x32 blocks.

Limitations

  1. The conversion from FP8 block scaling scaling factors to MXFP8 scaling factors is lossless if and only if the original FP8 block scaling scaling factors are powers of 2. This is because the MXFP8 scaling factors are not FP32 (like the FP8 block scaling scaling factors), but rather FP8E8M0. The conversion kernels assume this requirement is met and simply extract the FP32 exponent bits and treat them as FP8E8M0. If either the sign bit or any mantissa bits are set, the results will be incorrect, as the exponent bits are not masked out when performing bit shifts. Masking them out would result in numerical discrepancies in the output anyway due to discarding them.

  2. Despite losslessly converting the tensors, the GEMM outputs are not identical between Hopper and Blackwell because the Blackwell MXFP8 cuBLASLt GEMM is implemented differently from the Hopper FP8 block scaling cuBLASLt GEMM. However, the numerical error overall should be smaller compared to simply switching to the MXFP8 recipe.

  3. Contrary to FP8 Block Scaling on Hopper, GEMM+GELU fusion is not currently supported as cuBLASLt doesn't support it for MXFP8 (with BF16 output, which the FP8 Block Scaling recipe uses).

Future optimizations to pursue

  • Take advantage of Blackwell support for FP8 non-TN GEMMs. Contrary to Hopper, Blackwell supports non-TN GEMMs for FP8. This leads to at least the following two optimizations:

    1. Don't create columnwise data for weights. Because weights are 2D-block-scaled, their quantized columnwise data is simply the transpose of the rowwise data. As such, columnwise data doesn't have to be created in the forward pass, as the rowwise data can be used in the dgad GEMM.
    2. Don't transpose data after all gather. Currently, _post_process_fp8_blockwise_gather in distributed.py transposes columnwise data after all gather. This could be avoided.
  • Use a multi kernel for scaling factor swizzling. Currently, the support for te_general_grouped_gemm, which is needed for MoE, is naive. The tensor conversion function convert_block_scaling_to_mxfp8_tensor is simply called in a loop for every tensor. This can result in many small kernel launches, which is inefficient. Instead, a multi kernel approach could be used, similar to how MXFP8 scaling factor swizzling is handled.

  • Use preloading to reduce scoreboard stalls. Currently, the kernels are memory-latency bound. Before they can start loading data, they have to do a bit of address calculation, which prevents latency hiding. To alleviate this, the kernels could prefetch the entire scaling factor tensor to L2 with cp.async.bulk.prefetch.global.L2 at the very beginning before calculating which part of the tensor they need to read specifically, as the tensors should be small enough to fit in L2, and they will have to be read in their entirety anyway.

Performance notes

Regarding the performance of the 1D swizzling kernel:

  • Using warp shuffles (as it is currently being done) is 4.7% faster than using shared memory, according to NCU. I think that using a grid-stride loop would improve this benefit even further if shuffles from the different tiles could be interleaved for latency hiding.
  • It may seem like the first shuffle is redundant. After all, couldn't we just calculate lane_load_idx before the scaling factors being loaded and index reinterpret_cast<const uint4*>(warp_src) with lane_load_idx rather than lane? We could, and this is what I originally did. However, it turned out that it is more performant to start loading the data as soon as possible, minimizing the amount of computations that have to be done before the first load. As such, I changed the kernel to the current approach in aeafe79. After preloading the tensors to L2, this might no longer be neccesary.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add nvte_swizzle_block_scaling_to_mxfp8_scaling_factors core function for swizzling the FP8 block scaling GEMM_READY scaling factors to MXFP8 swizzled format.
  • Add swizzle_block_scaling.cu with two new custom kernels for swizzling the 1D and 2D FP8 block scaling scaling factors.
  • On Blackwell, when quantizing tensors to the FP8 Block Scaling format (in quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu) make sure the power-of-two scaling factors are used. The quantization itself does work with non-power-of-two scaling factors, but the check is done here, as it is the only place where this can be checked. In the GEMM, there is no longer a way to determine whether the tensors use power-of-two scaling factors.
  • On Blackwell, convert FP8 Block Scaling tensors to MXFP8 tensors right before the GEMM in gemm and te_general_grouped_gemm in gemm.cpp.
  • Change testing for support of the FP8 Block Scaling recipe to allow in on Blackwell in the Pytorch frontend in fp8.py.
  • On Blackwell, skip tests with non power of two scaling factors in test_cast_float8blockwise.cu and test_float8_blockwise_scaling_exact.py.
  • On Blackwell, skip test_float8_blockwise_gemm_exact.py because the MXFP8 GEMM has some numerical discrepancies with the FP8 Blockwise Scaling GEMM and is not bitwise the same as these tests expect. Other tests like test_numerics.py will check that the GEMMs are within acceptable numerical errors.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@janekb04 janekb04 changed the title Deepseek blackwell Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell Sep 5, 2025
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
@janekb04 janekb04 force-pushed the deepseek-blackwell branch 4 times, most recently from d7e794a to 7e7bf91 Compare September 9, 2025 21:59
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
…ewer in GEMM

Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
janekb04 and others added 4 commits September 10, 2025 20:26
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
@janekb04
Copy link
Collaborator Author

/te-ci pytorch L0

@janekb04
Copy link
Collaborator Author

/te-ci pytorch L0

@janekb04
Copy link
Collaborator Author

/te-ci L0

janekb04 and others added 4 commits September 18, 2025 00:14
@janekb04
Copy link
Collaborator Author

/te-ci L0 L1

@janekb04
Copy link
Collaborator Author

/te-ci pytorch L0

@janekb04
Copy link
Collaborator Author

/te-ci pytorch L0

@janekb04 janekb04 marked this pull request as ready for review September 18, 2025 16:52
@janekb04
Copy link
Collaborator Author

/te-ci L0 L1

@timmoon10 timmoon10 removed the 2.8.0 label Sep 23, 2025
@janekb04 janekb04 requested a review from Copilot September 23, 2025 21:45
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR enables FP8 Block Scaling (DeepSeek recipe) support on Blackwell architecture by emulating it with MXFP8, as native FP8 block scaling is only supported on Hopper. The implementation converts FP8 block scaling tensors to MXFP8 format before GEMM operations while preserving the original quantization behavior.

Key Changes:

  • Removes architecture restriction limiting FP8 block scaling to Hopper only
  • Adds tensor conversion from FP8 block scaling to MXFP8 format for Blackwell compatibility
  • Implements custom CUDA kernels for scaling factor format conversion

Reviewed Changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
transformer_engine/pytorch/fp8.py Updates support check to allow Blackwell architecture
transformer_engine/pytorch/distributed.py Simplifies transpose logic for block scaling tensors
transformer_engine/pytorch/csrc/util.h Adds function declaration for block scaling to MXFP8 conversion
transformer_engine/pytorch/csrc/util.cpp Implements tensor conversion from block scaling to MXFP8 format
transformer_engine/pytorch/csrc/extensions/gemm.cpp Adds Blackwell-specific tensor conversion logic in GEMM operations
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu Adds power-of-two validation for Blackwell
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Adds power-of-two validation for Blackwell
transformer_engine/common/transformer_engine.cpp Adds string representations for block scaling modes
transformer_engine/common/swizzle/swizzle_block_scaling.cu New CUDA kernels for scaling factor format conversion
transformer_engine/common/include/transformer_engine/swizzle.h Adds API declaration for scaling factor swizzling
transformer_engine/common/CMakeLists.txt Adds new swizzle kernel source file
tests/pytorch/test_float8_blockwise_scaling_exact.py Skips non-power-of-two tests on Blackwell
tests/pytorch/test_float8_blockwise_gemm_exact.py Disables exact GEMM tests for emulated mode
tests/cpp/operator/test_cast_float8blockwise.cu Skips non-power-of-two tests on Blackwell

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
@janekb04 janekb04 changed the title Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell [Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell Sep 24, 2025
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10
Copy link
Collaborator

/te-ci L1

Comment on lines +273 to +280
NVTE_CHECK(output_scale_inv_rows == DIVUP<size_t>(data_rows, 128) * 128,
"Expected the output scaling factor matrix to have ",
DIVUP<size_t>(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows,
" rows instead.");
NVTE_CHECK(output_scale_inv_cols == DIVUP<size_t>(data_cols, 128) * 4,
"Expected the output scaling factor matrix to have ",
DIVUP<size_t>(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols,
" columns instead.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be advised that these dimensions are lies. The MXFP8 scales are actually laid out with dims (data_rows / 128, data_cols / 128, 128*4).

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10
Copy link
Collaborator

/te-ci L1

Copy link
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, pending CI

@timmoon10 timmoon10 merged commit dfe5b7d into NVIDIA:main Oct 3, 2025
49 of 53 checks passed
@yaox12
Copy link
Member

yaox12 commented Nov 5, 2025

Hi @janekb04, just to confirm, with this PR, we still need both rowwise and colwise (transposed) FP8 weights on Blackwell, right? Even though Blackwell supports all layouts for FP8 GEMMs and the colwise quantized weights are just transposed rowwise ones because of the 2D scaling. Am I understanding it correctly?

@janekb04
Copy link
Collaborator Author

janekb04 commented Nov 5, 2025

Hi @janekb04, just to confirm, with this PR, we still need both rowwise and colwise (transposed) FP8 weights on Blackwell, right? Even though Blackwell supports all layouts for FP8 GEMMs and the colwise quantized weights are just transposed rowwise ones because of the 2D scaling. Am I understanding it correctly?

Yes, this PR still uses both rowwise and columnwise data. Indeed, in the case of 2d block scaling, this could be avoided because the columnwise data is the transpose of the rowwise data. In case of 1d block scaling, however, this could increase numerical differences because the data would be quantized in the other direction, and would not necessarily be the same after dequantization.

In the section on future optimizations I point the two optimizations regarding this I found that could be applicable:

Don't create columnwise data for weights. Because weights are 2D-block-scaled, their quantized columnwise data is simply the transpose of the rowwise data. As such, columnwise data doesn't have to be created in the forward pass, as the rowwise data can be used in the dgad GEMM.
Don't transpose data after all gather. Currently, _post_process_fp8_blockwise_gather in distributed.py transposes columnwise data after all gather. This could be avoided.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants