From c045997f05fe545818ee4b1b251760c2f6c05beb Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 12 Feb 2025 13:50:52 -0800 Subject: [PATCH 01/38] Blockwise float8 quantizer and quantized tensor class. The classes are configurable for 128x128 blocksize and 1x128 blocksize via setting block_scaling_dim == 2,1 respectively. Scale tensors are stored in a format emenable for matrix multiplication, however the integration of matmul is deferred as a separate story. Fusions of quantization and DBIAS or activation functions are not yet implemented, and the dequantization is currently implemented in torch. Tests for quantization are included in C++ and pytorch layers, with exact comparison to reference quantizer behavior as well as an attempt to hit interesting branches through the API such as tensor creation in pytorch and CPP and dequantization of row and columnwise usage. Two CUDA kernels for quantization are included, and are direct ports of equivalents in the kitchen repository, where a subchannel recipe has been used for end to end training. Signed-off-by: Keith Wyss --- tests/cpp/operator/CMakeLists.txt | 1 + .../cpp/operator/test_cast_float8blockwise.cu | 640 ++++++++++++++++++ tests/cpp/test_common.cu | 159 +++-- tests/cpp/test_common.h | 35 +- .../blockwise_quantizer_reference.py | 361 ++++++++++ .../test_float8_blockwise_scaling_exact.py | 291 ++++++++ tests/pytorch/test_float8blockwisetensor.py | 201 ++++++ transformer_engine/common/CMakeLists.txt | 2 + transformer_engine/common/common.h | 36 +- .../common/gemm/cublaslt_gemm.cu | 8 + .../common/include/transformer_engine/cast.h | 14 +- .../transformer_engine/transformer_engine.h | 84 ++- .../common/transformer_engine.cpp | 45 ++ .../common/transpose/cast_transpose.h | 12 + .../common/transpose/compute_scale.cuh | 134 ++++ .../quantize_transpose_square_blockwise.cu | 603 +++++++++++++++++ .../quantize_transpose_vector_blockwise.cu | 528 +++++++++++++++ .../common/util/cast_kernels.cuh | 21 + .../common/util/dequantize_kernels.cuh | 1 + transformer_engine/pytorch/constants.py | 6 + transformer_engine/pytorch/csrc/common.h | 30 + .../pytorch/csrc/extensions/pybind.cpp | 26 + .../pytorch/csrc/extensions/quantizer.cpp | 128 ++++ .../csrc/extensions/type_converters.cpp | 32 + transformer_engine/pytorch/csrc/pybind.h | 21 +- .../_internal/float8_blockwise_tensor_base.py | 246 +++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 539 +++++++++++++++ 27 files changed, 4129 insertions(+), 75 deletions(-) create mode 100644 tests/cpp/operator/test_cast_float8blockwise.cu create mode 100644 tests/pytorch/references/blockwise_quantizer_reference.py create mode 100644 tests/pytorch/test_float8_blockwise_scaling_exact.py create mode 100644 tests/pytorch/test_float8blockwisetensor.py create mode 100644 transformer_engine/common/transpose/compute_scale.cuh create mode 100644 transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu create mode 100644 transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu create mode 100644 transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/float8_blockwise_tensor.py diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 6785dbf6f4..0b0e615495 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu new file mode 100644 index 0000000000..171d22be71 --- /dev/null +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -0,0 +1,640 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +constexpr size_t kBlockLen = 128; + +enum ProcessingMethod { + CAST_ONLY, + // CAST_DBIAS, + // CAST_DBIAS_DACT, + // CAST_DACT, + // CAST_ACT +}; + +enum ActivationType { + Identity, + // GeLU, + // SiLU, + // ReLU, + // QGeLU, + // SReLU +}; + +template +void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale_out, + float* qscale_inv_out) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + float eps = opts.amax_epsilon; + amax = std::max(amax, eps); + float qscale = quant_type_max_val / amax; + if (std::isinf(qscale)) { + qscale = input_type_max_val; + } + if (std::isnan(qscale) || amax == 0) { + qscale = 1.0; + } + + if (opts.force_pow_2_scales && qscale != 0.0) { + uint32_t scale_bits = *reinterpret_cast(&qscale); + // Scale must be positive, shift it + uint8_t exp = scale_bits >> 23; + ASSERT_FALSE(exp == 0) << "Subnormals in this path is a logic error."; + qscale = ldexpf(1.0f, static_cast(exp) - 127); + } + + float qscale_inv = 1.0 / qscale; + *qscale_out = qscale; + *qscale_inv_out = qscale_inv; +} + +template +void ref_quantize(const ProcessingMethod processing_method, const InputType* input, + const std::pair& input_hw, OutputType* output, float* scale_inv, + OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) { + constexpr size_t kBlockLenX = kBlockLen; + constexpr size_t kBlockLenY = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_y = (height + kBlockLenY - 1) / kBlockLenY; + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t block_y = 0; block_y < blocks_y; ++block_y) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + float val = static_cast(input[y_pos * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + // NOTE: This reference function outputs contigous scale tensors. + // It calculates a naive scale data format. Strides are handled + // in comparison. + if (scale_inv != nullptr) { + scale_inv[block_y * blocks_x + block_x] = qscale_inv; + } + if (scale_inv_t != nullptr) { + scale_inv_t[block_x * blocks_y + block_y] = qscale_inv; + } + + for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + if (output != nullptr) { + output[y_pos * width + x_pos] = quantize_element(input[y_pos * width + x_pos], qscale); + } + if (output_t != nullptr) { + output_t[x_pos * height + y_pos] = + quantize_element(input[y_pos * width + x_pos], qscale); + } + } + } + } + } +} + +template +void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method, + const InputType* input, + const std::pair& input_hw, + OutputType* output, float* scale_inv, OutputType* output_t, + float* scale_inv_t, const QuantizationOptions& opts) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + + constexpr size_t kBlockLenX = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_x_t = (height + kBlockLenX - 1) / kBlockLenX; + if (output != nullptr && scale_inv != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t y = 0; y < height; ++y) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + float val = static_cast(input[y * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv[y + height * block_x] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + output[y * width + x_pos] = quantize_element(input[y * width + x_pos], qscale); + } + } + } + } + if (output_t != nullptr && scale_inv_t != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x_t = 0; block_x_t < blocks_x_t; ++block_x_t) { + for (size_t x = 0; x < width; ++x) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + float val = static_cast(input[x + y_pos * width]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv_t[x + width * block_x_t] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + output_t[x * height + y_pos] = quantize_element(input[y_pos * width + x], qscale); + } + } + } + } +} + +void compare_scaling_factors(const std::string& name, const float* test, const float* ref, + const size_t row_blocks, const size_t col_blocks, + const size_t test_stride, const size_t ref_stride) { + for (int i = 0; i < row_blocks; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i * test_stride + j; + const int ref_idx = i * ref_stride + j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, + const float* ref, const size_t rows, + const size_t col_blocks) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i + rows * j; + const int ref_idx = i + rows * j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +template +void runTestCase(const ProcessingMethod processing_method, const std::vector& shape, + const bool rowwise, const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_y = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_BLOCK_SCALING, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(blocks_y * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(blocks_y * blocks_x); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize(processing_method, input.rowwise_cpu_dptr(), + {rows, cols}, ref_output.get(), ref_scale_inv.get(), + ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + auto scale_align_stride = [](size_t inner_elements) -> size_t { + return ((inner_elements + 4u - 1u) / 4u) * 4u; + }; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors("scale_inv", output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), blocks_y, blocks_x, scale_align_stride(blocks_x), + blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors("scale_inv_t", output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), blocks_x, blocks_y, scale_align_stride(blocks_y), + blocks_y); + } +} + +template +void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, + const std::vector& shape, const bool rowwise, + const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_x_t = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_BLOCK_SCALING, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(rows * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(cols * blocks_x_t); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize_onedimensional_blocks( + processing_method, input.rowwise_cpu_dptr(), {rows, cols}, ref_output.get(), + ref_scale_inv.get(), ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv", + output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), rows, blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv_t", + output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), cols, blocks_x_t); + } +} + +std::vector> matrix_sizes = { + {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, + {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, + {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, +}; + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + // ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + + +std::vector amax_epsilons = { + 0.0f, + // Set large epsilon to get observable behavior. + 0.1f, +}; + +} // namespace + +class FusedCastFloat8BlockwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +class FusedCastFloat8VectorwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 2u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. + GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCase(processing_method, matrix_size, rowwise, colwise, + fill_case, q_opts);););); +} + +TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 1u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. + GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCaseOneDimensionalBlocks( + processing_method, matrix_size, rowwise, colwise, fill_case, q_opts);););); +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: + return "CAST_ONLY"; + // case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + // case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + // case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + // case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: + return ""; + } +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: + return "Identity"; + // case ActivationType::GeLU: return "GeLU"; + // case ActivationType::SiLU: return "SiLU"; + // case ActivationType::ReLU: return "ReLU"; + // case ActivationType::QGeLU: return "QGeLU"; + // case ActivationType::SReLU: return "SReLU"; + default: + return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8BlockwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8VectorwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 855d70856a..d3faac5c28 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -116,7 +117,8 @@ NVTEShape convertShape(const std::vector& shape) { } std::pair get_scales(const NVTEShape& shape, - const NVTEScalingMode scaling_mode) { + const NVTEScalingMode scaling_mode, + const int block_scaling_dim) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { scale_inv_meta ret; ret.shape = {1}; @@ -134,27 +136,19 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul,4ul}; + auto block_alignment = std::vector{128ul, 4ul}; { auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat8E8M0; @@ -164,6 +158,61 @@ std::pair get_scales(const NVTEShape& shape, return {ret_rowwise, ret_colwise}; } + if (scaling_mode == NVTE_BLOCK_SCALING) { + if (block_scaling_dim == 2) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + + return {ret_rowwise, ret_colwise}; + } else if (block_scaling_dim == 1) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_1 = first_dim; + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_1 = last_dim; + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + return {ret_rowwise, ret_colwise}; + } else { + NVTE_ERROR("Unsupported block scaling dim!"); + } + } NVTE_ERROR("Invalid scaling mode!"); } @@ -171,7 +220,8 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode) { + const NVTEScalingMode &scaling_mode, + const QuantizationOptions* q_opts) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -197,8 +247,12 @@ Tensor::Tensor(const std::string& name, NVTEShape normalized_shape = convertShape(normalized_shape_v); NVTEShape columnwise_shape{nullptr, 0}; + size_t block_scaling_dim = 0; + if (q_opts != nullptr) { + block_scaling_dim = q_opts->block_scaling_dim; + } std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { @@ -259,27 +313,34 @@ Tensor::Tensor(const std::string& name, std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, - tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(normalized_shape, tensor_.scaling_mode(), block_scaling_dim); auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_shape = rowwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape; if (rowwise) { - cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape); + auto scale_dtype = rowwise_scale_meta.type; + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); } if (columnwise) { cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); - tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape); + auto scale_dtype = colwise_scale_meta.type; + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } + if (q_opts != nullptr) { + tensor_.set_qopt_force_pow_2_scales(q_opts->force_pow_2_scales); + tensor_.set_qopt_amax_epsilon(q_opts->amax_epsilon); + tensor_.set_qopt_block_scaling_dim(q_opts->block_scaling_dim); + } } } @@ -311,7 +372,8 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -349,7 +411,8 @@ void Tensor::from_cpu() const { cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -368,7 +431,7 @@ void Tensor::from_cpu() const { void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; from_cpu(); } @@ -383,27 +446,29 @@ void Tensor::set_scale_inv(float scale_inv) { if (columnwise_) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(tensor_.shape(), tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { rowwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } } if (columnwise_) { auto num_scales = product(colwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { columnwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } @@ -413,23 +478,20 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if(isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); - new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, - static_cast(my_rowwise_data.dtype), + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), my_rowwise_data.shape); auto my_columnwise_data = tensor_.get_columnwise_data(); new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, static_cast(my_columnwise_data.dtype), my_columnwise_data.shape); auto other_amax = other.tensor_.get_amax(); - new_tensor.set_amax(other_amax.data_ptr, - static_cast(other_amax.dtype), + new_tensor.set_amax(other_amax.data_ptr, static_cast(other_amax.dtype), other_amax.shape); auto other_scale = other.tensor_.get_scale(); - new_tensor.set_scale(other_scale.data_ptr, - static_cast(other_scale.dtype), + new_tensor.set_scale(other_scale.data_ptr, static_cast(other_scale.dtype), other_scale.shape); auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, @@ -460,9 +522,7 @@ std::string to_string(const std::vector &v) { std::vector unravel(const size_t i, const NVTEShape &shape) { std::vector ret; size_t current_i = i; - for (size_t current = shape.ndim - 1; - current > 0; - --current) { + for (size_t current = shape.ndim - 1; current > 0; --current) { ret.push_back(current_i % shape.data[current]); current_i /= shape.data[current]; } @@ -705,7 +765,7 @@ void fillCase_special(Tensor *t) { }); } else { double minAbs = -2.0; - double maxAbs = 1.0; + double maxAbs = 1.0; if constexpr (Case != InputsFillCase::uniform) { minAbs = Quantized_Limits::ranges[Case]; maxAbs = Quantized_Limits::ranges[Case + 1]; @@ -764,14 +824,13 @@ void setRandomScaleInv(Tensor *t) { } bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; +int32_t getDeviceComputeCapability() { + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; } size_t first_dimension(const std::vector &shape) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 4352056ddb..8f26ac7419 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,21 +95,29 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} Tensor() {} @@ -136,25 +144,19 @@ class Tensor { if (scale_inv != nullptr) { cudaFree(scale_inv); } - if (columnwise_data_ptr != nullptr){ + if (columnwise_data_ptr != nullptr) { cudaFree(columnwise_data_ptr); } - if (columnwise_scale_inv != nullptr){ + if (columnwise_scale_inv != nullptr) { cudaFree(columnwise_scale_inv); } } - NVTETensor data() const noexcept { - return tensor_.data(); - } + NVTETensor data() const noexcept { return tensor_.data(); } - NVTEShape rowwise_shape() const noexcept { - return tensor_.get_rowwise_data().shape; - } + NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; } - NVTEShape columnwise_shape() const noexcept { - return tensor_.get_columnwise_data().shape; - } + NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; } NVTEShape rowwise_scale_inv_shape() const { NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); @@ -221,6 +223,8 @@ class Tensor { T *rowwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -232,6 +236,8 @@ class Tensor { T *columnwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -459,6 +465,7 @@ extern std::vector all_fp_types; bool isFp8Type(DType type); int32_t getDeviceComputeCapability(); +constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; } // namespace test diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py new file mode 100644 index 0000000000..72cb062c31 --- /dev/null +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -0,0 +1,361 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import dataclasses +import math +import torch +from typing import Optional, Protocol, Tuple + + +@dataclasses.dataclass() +class QuantizeResult: + data: torch.Tensor + scale: torch.Tensor + data_t: Optional[torch.Tensor] + scale_t: Optional[torch.Tensor] + + +# FIXME(kwyss): Put this in a common location for per-tensor current +# scaling reference +def _scale_from_amax_tensor( + x_dtype: torch.dtype, + amax: torch.Tensor, + quant_dtype: torch.dtype, + *, + eps: float, + pow_2_scales: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Derives quantization and dequantization from amax and options. + + Reference implementation for scale calculation. + + Returns: + - scale: quantization scales + - scale_inv: dequantization scales + - amax: Amax tensor with updates made for extrema values. + """ + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + + # Compute scale factor + scale = torch.div(fp8_max, amax) + # Note frexp doesn't give back inf for exponent with an inf input + # We take care of inf before pow_2_scales + scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) + if pow_2_scales: + # Calculate rounded down exponent + _, exp = torch.frexp(scale) + # Positive numbers are always returned as mant, exp with + # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with + # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because + # of the shift. Subnormal and zero cases need not be considered because + # the smallest possible result of fp8_max / amax is still normal. + exp = exp - 1 + # No subnormals and zero. + assert (exp > -127).all() + unity = torch.tensor([1.0], device=exp.device) + torch.ldexp(unity, exp, out=scale) + # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales + # Return 0.0 for 0.0 scale for consistency with non-pow2 scale + # calculation. + scale = torch.where(amax == float("inf"), 0.0, scale) + + # Handle overflow cases for amax zero causing NaN + scale = torch.where(amax == 0, 1.0, scale) + + # Compute scale_inv + scale_inv = torch.reciprocal(scale) + + return scale, scale_inv, amax + + +@dataclasses.dataclass() +class CuBLASScaleMunger: + + def munge_scale_shapes_for_backend( + self, + unmunged: QuantizeResult, + tile_shape: Tuple[int, int], + ) -> QuantizeResult: + """ + cuBLAS GEMMs requires 1x128 quantized tensors to be have scales transposed + so that for an (M, N) tensor, the scales are (RounUpDiv(N, 128), M) + + For 128x128 quantized tensors, the GEMM expects (M, PadToAlign(RoundUpDivide(N, 128), 4)) + format. If RoundUpDivide(N, 128) is not divisible by 4, a transformation is required + """ + if tile_shape[0] != 1: + # 2D block quantized tensor needs padding for cuBLAS GEMM. + def _munge_scale_tensor(s: torch.Tensor) -> torch.Tensor: + M, K = s.shape + if K % 4 == 0: + return s + k_pad = 4 - (K % 4) + return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous() + + s = _munge_scale_tensor(unmunged.scale) + if unmunged.scale_t is None: + s_t = None + else: + s_t = _munge_scale_tensor(unmunged.scale_t) + return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + + # 1D block quantized tensors needs transpose to prepare for the GEMM. + s = unmunged.scale.transpose(-1, -2).contiguous() + if unmunged.scale_t is None: + s_t = None + else: + s_t = unmunged.scale_t.transpose(-1, -2).contiguous() + return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + + def demunge_scale_shape_from_backend( + cls, + qtensor_shape: Tuple[int, int], + scales: torch.Tensor, + tile_shape: Tuple[int, int], + ) -> torch.Tensor: + """ + Inverse operation of munge_scale_shapes_for_backend + """ + if tile_shape[0] != 1: + # 2D block quantized tensor may need padding stripped off + derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1]) + M, K = scales.shape + if derived_scale_k_shape == K: + return scales + else: + return scales[:, :derived_scale_k_shape].contiguous() + return scales.transpose(-1, -2).contiguous() + + +@dataclasses.dataclass() +class BlockwiseQuantizerReference: + """ + A reference QuantizeOp for subchannel/block hybrid quantization. + + Defers to ref GEMMs and quantizization formatting based on the backend. + """ + + def __init__(self) -> None: + self.scale_munger = CuBLASScaleMunger() + + @classmethod + def _quantize_square_block_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + pad_m_k = [0, 0] + if K % tile_len != 0: + pad_m_k[1] = tile_len - (K % tile_len) + if M % tile_len != 0: + pad_m_k[0] = tile_len - (M % tile_len) + + unpadded_m, unpadded_k = M, K + if pad_m_k[0] != 0 or pad_m_k[1] != 0: + x = torch.nn.functional.pad( + x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0 + ).contiguous() + M, K = x.shape + + x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len) + amax_grid = ( + torch.abs(x_tiled.transpose(-3, -2)) + .reshape(M // tile_len, K // tile_len, tile_len**2) + .amax(dim=-1) + ).float() + dtype_max = torch.finfo(quant_dtype).max + + scale, scale_inv, _ = _scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + if unpadded_k != K or unpadded_m != M: + qx = qx[:unpadded_m, :unpadded_k].contiguous() + if return_transpose: + # Valid because of square block sizes + qx_t = qx.transpose(-1, -2).contiguous() + scale_inv_t = scale_inv.transpose(-1, -2).contiguous() + else: + qx_t = None + scale_inv_t = None + + return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t) + + @classmethod + def _quantize_vectorwise_reference( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + pow_2_scales: bool, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + M, K = x.shape + dtype_max = torch.finfo(quant_dtype).max + x_tiled = x.reshape(M, K // tile_len, tile_len) + amax_grid = torch.abs(x_tiled).amax(dim=-1).float() + scale, scale_inv, _ = _scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + return qx, scale_inv + + @classmethod + def _quantize_vector_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + if K % tile_len == 0: + qref_input = x + else: + pad_amount = tile_len - (K % tile_len) + pad = (0, pad_amount) + qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0) + qout_padded, scale_inv = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if K % tile_len == 0: + qout = qout_padded + else: + qout = qout_padded[:, :K].contiguous() + + if return_transpose: + if M % tile_len == 0: + qref_input = x.transpose(-1, -2).contiguous() + else: + amount_to_pad = tile_len - (M % tile_len) + pad = (0, amount_to_pad) + qref_input = torch.nn.functional.pad( + x.transpose(-1, -2), pad, mode="constant", value=0 + ).contiguous() + qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if M % tile_len == 0: + qout_t = qout_t_padded + else: + qout_t = qout_t_padded[:, :M].contiguous() + else: + qout_t, scale_inv_t = None, None + + return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t) + + def ref_dequantize_rowwise( + self, + q: torch.Tensor, + quant_tile_shape: Tuple[int, int], + s: torch.Tensor, + dtype: torch.dtype, + ) -> torch.Tensor: + assert q.dim() == 2 + q_M, q_K = q.shape + s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape) + assert len(s.shape) == 2 + m_tiles, k_tiles = s.shape + M, K = q.shape + unpadded_m, unpadded_k = M, K + if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0: + m_pad_amount = (quant_tile_shape[0] - (M % quant_tile_shape[0])) % quant_tile_shape[0] + k_pad_amount = (quant_tile_shape[1] - (K % quant_tile_shape[1])) % quant_tile_shape[1] + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + M, K = q.shape + q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1]) + result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1) + result = result.view(M, K).to(dtype) + if M != unpadded_m or K != unpadded_k: + result = result[:unpadded_m, :unpadded_k].contiguous() + return result + + def quantize( + self, + x: torch.Tensor, + quant_dtype: torch.dtype, + return_transpose: bool = False, + eps: float = 0.0, + pow_2_scales: bool = False, + quant_tile_shape: Tuple[int, int] = (128, 128), + ) -> QuantizeResult: + # sanity checks + assert x.dim() == 2 + assert x.dtype in ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, + ), "Unsupported input dtype." + assert quant_dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ), "Unsupported quant dtype." + + assert quant_tile_shape in ((1, 128), (128, 128)) + if quant_tile_shape[0] == 1: + # Quantize row-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_vector_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[1], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) + else: + # Quantize block-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_square_block_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[0], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py new file mode 100644 index 0000000000..16647184d6 --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -0,0 +1,291 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple +import math +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from tests.pytorch.references.blockwise_quantizer_reference import ( + BlockwiseQuantizerReference, + QuantizeResult, +) + + +def initialize_for_many_scales( + x_shape_2d: Tuple[int, int], tile_shape: Tuple[int, int], *, dtype: torch.dtype, device: str +) -> torch.Tensor: + """ + Put separate distributions into each quantization tile + to avoid many tiles having similar scale values and + causing false passes. + """ + tile_grid_shape = ( + math.ceil(x_shape_2d[0] / tile_shape[0]), + math.ceil(x_shape_2d[1] / tile_shape[1]), + ) + # Arbitrary size + max_val = 8192.0 + # Make a uniform distribution of [-max_val, max_val] + tile_extrema = torch.rand(*tile_grid_shape, dtype=dtype) * max_val * 2 - max_val + result = torch.empty(x_shape_2d, dtype=dtype, device=device) + tile_elements = tile_shape[0] * tile_shape[1] + for i in range(tile_grid_shape[0]): + for j in range(tile_grid_shape[1]): + target = tile_extrema[i, j].item() + step = target / (tile_elements) + if target == 0: + tile = torch.zeros(tile_shape, dtype=dtype, device=device) + else: + tile = torch.arange(0.0, target, step=step, dtype=dtype, device=device) + tile = tile.reshape(*tile_shape) + min_dst_vals = (i * tile_shape[0], j * tile_shape[1]) + max_dst_vals = ( + min((i + 1) * tile_shape[0], x_shape_2d[0]), + min((j + 1) * tile_shape[1], x_shape_2d[1]), + ) + max_src_vals = ( + max_dst_vals[0] - min_dst_vals[0], + max_dst_vals[1] - min_dst_vals[1], + ) + result[min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1]] = tile[ + : max_src_vals[0], : max_src_vals[1] + ] + return result + + +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (300, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0, 1e-12], ids=["eps_0", "eps_1e-12"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim = 1 + elif tile_size == (128, 128): + block_scaling_dim = 2 + else: + raise ValueError("Non support tile size") + # This test runs a comparison of the ref class versus the class using + # CUDA kernels to quantize. They should quantize identically for pixels + # that are not DC values in the scale factor shape. + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device) + + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) + + assert x_fp8_sut._rowwise_data is not None + qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + assert x_fp8_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv + qx_t = x_fp8_sut._columnwise_data + sx_t = x_fp8_sut._columnwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, + ) + qx_ref, sx_ref, qx_t_ref, sx_t_ref = ( + qresult_ref.data, + qresult_ref.scale, + qresult_ref.data_t, + qresult_ref.scale_t, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + if tile_size[0] != 1: + # Zero out values that are don't care values + # cuBLAS has padding of 2D tensors. + scale_mask = torch.ones( + (math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx, scale_mask, None, None), tile_size + ).scale + sx = sx * scale_mask + + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + assert qx_t is not None + qx_t = qx_t.view(dtype=quant_dtype) + assert qx_t_ref is not None + assert sx_t is not None + assert sx_t_ref is not None + if tile_size[0] != 1: + scale_mask = torch.ones( + (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])), + device=sx_t.device, + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx_t, scale_mask, None, None), tile_size + ).scale + sx_t = sx_t * scale_mask + torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0) + else: + # should be None + assert qx_t is None and qx_t_ref is None + assert sx_t is None and sx_t_ref is None + + +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (1, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0, math.pow(2, -125)], ids=["eps_0", "eps_small"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128)]) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +def test_quantization_block_tiling_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + pow_2_scales: bool, + tile_size: Tuple[int, int], + extrema_high: bool, +) -> None: + # This test runs a single tile through a quantizer as a way to test + # branch coverage of scale computation. + te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim = 1 + elif tile_size == (128, 128): + block_scaling_dim = 2 + else: + raise ValueError("Non support tile size") + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + return_transpose = False + # Input + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + # Run cast and transpose kernel + # Internal call ops.quantize_tensorwise + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) + qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + sx = x_fp8_sut._rowwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, + ) + qx_ref, sx_ref = ( + qresult_ref.data, + qresult_ref.scale, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + + if extrema_high: + expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max + if pow_2_scales: + expected_value = math.floor(math.log2(expected_value)) + expected_value = math.pow(2.0, expected_value) + expected_value = 1 / expected_value + elif not extrema_high and eps == 0: + expected_value = 1.0 + else: + assert not extrema_high + # eps is small enough to trigger inf in quant_dtype_max / eps + if pow_2_scales: + expected_value = math.pow(2.0, -127) + else: + expected_value = 1 / torch.finfo(x_dtype).max + torch.testing.assert_close( + sx, + torch.tensor([expected_value], device=sx.device).reshape(1, 1), + atol=0.0, + rtol=0.0, + ) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py new file mode 100644 index 0000000000..7058fdb22f --- /dev/null +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from collections.abc import Iterable +import io +import math +from typing import Any, Dict, List, Tuple, Union + +import pytest +import torch + +import transformer_engine.common.recipe +import transformer_engine.pytorch as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +import transformer_engine_torch as tex + +# PyTorch tensor dtypes +_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] +# TE FP8 dtypes +_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + +# Numerical tolerances with FP8 types +_tols: Dict[tex.DType, Dict[str, float]] = { + tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 + tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 +} + + +def _to_list(x: Union[Iterable, Any]) -> List: + """Convert to list if iterable, otherwise put in singleton list""" + if isinstance(x, Iterable): + return list(x) + else: + return [x] + + +# Types that can be interpreted as tensor dims +DimsType = Union[Iterable[int], int] + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFloat8BlockwiseTensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_constructor( + self, + dims: DimsType = 1, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + dtype: torch.dtype = torch.float32, + ) -> None: + """Call constructor and perform sanity checks""" + dims = _to_list(dims) + + rowwise = True + columnwise = True + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, rowwise=rowwise, columnwise=columnwise + ) + + scale_dims = quantizer.get_scale_shape(dims, columnwise=False) + columnwise_scale_dims = quantizer.get_scale_shape(dims, columnwise=True) + columnwise_dims = quantizer.get_columnwise_shape(dims) + tensor = Float8BlockwiseQTensor( + shape=dims, + dtype=dtype, + rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8), + rowwise_scale_inv=torch.zeros(scale_dims, device="cuda", dtype=torch.float32), + columnwise_data=torch.zeros(columnwise_dims, device="cuda", dtype=torch.uint8), + columnwise_scale_inv=torch.zeros( + columnwise_scale_dims, device="cuda", dtype=torch.float32 + ), + fp8_dtype=fp8_dtype, + quantizer=quantizer, + ) + assert list(tensor.size()) == dims, "Incorrect dims" + assert tensor.dtype == dtype, "Incorrect nominal dtype" + assert tensor.is_cuda, "Incorrect device" + + def _test_quantize_dequantize( + self, + quantizer: Float8BlockQuantizer, + dtype: torch.dtype = torch.float32, + dims: DimsType = (23, 128), + rtol: float = 0.0, + atol: float = 0.0, + dequant_columnwise: bool = False, + use_cpp_allocation: bool = False, + ) -> None: + """Check numerical error when casting to FP8 and back""" + dims = _to_list(dims) + + # Initialize random data + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref_cuda = x_ref.to("cuda") + + # Cast to FP8 and back + if not use_cpp_allocation: + x_fp8 = quantizer.make_empty(shape=dims, device="cuda") + quantizer.update_quantized(x_ref_cuda, x_fp8) + else: + # This codepath allows the CPP binding to allocate the output + # tensor + x_fp8 = tex.quantize(x_ref_cuda, quantizer, None, None) + if dequant_columnwise: + # Strip out rowwise data to verify dequantization of + # columnwise data. + x_fp8.update_usage(rowwise_usage=False, columnwise_usage=True) + x_fp8 = x_fp8.dequantize(dtype=dtype).cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, rtol=rtol, atol=atol) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_quantize_dequantize_dtypes( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=False, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims( + self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool + ) -> None: + atol = _tols[tex.DType.kFloat8E4M3]["atol"] + rtol = _tols[tex.DType.kFloat8E4M3]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + ) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims_cpp_allocate_output( + self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + use_cpp_allocation=True, + ) + + # FIXME(kwyss): Add some testing for other tensor operations. + # - basic_ops + # - in_place_ops + # - serialization + # - set_data diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index deeb3c3862..4a7df7d1aa 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -58,6 +58,8 @@ list(APPEND transformer_engine_SOURCES transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise.cu activation/gelu.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ac58398551..36106e0110 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -101,6 +101,10 @@ struct Tensor { NVTEScalingMode scaling_mode; + float amax_epsilon; + bool force_pow_2_scales; + int block_scaling_dim; + Tensor() : data(), columnwise_data(), @@ -108,7 +112,10 @@ struct Tensor { scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), - scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + scaling_mode(NVTE_DELAYED_TENSOR_SCALING), + amax_epsilon(0.0), + force_pow_2_scales(false), + block_scaling_dim(scaling_mode == NVTE_BLOCK_SCALING ? 2 : 0) {} int numel() const { size_t acc = 1; @@ -125,6 +132,33 @@ struct Tensor { return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0; } + bool supports_force_pow_2_scales_qopt() const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return true; + default: + return false; + } + } + + bool supports_amax_epsilon_qopt() const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return true; + default: + return false; + } + } + + bool supports_block_scaling_dim(int block_scaling_dim) const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return block_scaling_dim == 1 || block_scaling_dim == 2; + default: + return false; + } + } + DType dtype() const { if (has_data()) return data.dtype; if (has_columnwise_data()) return columnwise_data.dtype; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 3234e087c3..f19465c44b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -81,6 +81,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, const int k, const int lda, const int ldb) { using namespace transformer_engine; + // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. + // Must either force them both into a common block scaling mode or loosen this + // restriction. NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); @@ -90,6 +93,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.lda = lda; ret.ldb = ldb; + // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases + // or need to be treated as `is_tensor_scaling`. if (is_tensor_scaling(A.scaling_mode)) { ret.A = A.data.dptr; ret.A_scale_inv = A.scale_inv.dptr; @@ -244,6 +249,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); + // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized + // GEMM types. + // Scaling factors. #if CUDA_VERSION >= 12080 cublasLtMatmulMatrixScale_t scaling_mode; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index d57975b2f4..9be0e14d8a 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -42,23 +42,25 @@ extern "C" { * of the output tensor should be set to 0. */ -/*! \brief Casts input tensor to FP8/MXFP8. +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. + * the MXFP8 block quantization of the specified shape of the block will be used. + * If the scaling mode of the output tensor is set to NVTE_BLOCK_SCALING, + * blockwise float8 scaling will be used. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output FP8/MXFP8/BlockwiseFP8 tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. - * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output quantized tensor. * \param[out] noop Noop tensor. * \param[in] stream CUDA stream used for the operation. */ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 70086a1811..bae30e3e05 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -80,8 +80,13 @@ enum NVTEScalingMode { /*! Single scale per block of 32 elements consecutive in either rowwise or columnwise direction */ NVTE_MXFP8_1D_SCALING = 1, - NVTE_INVALID_SCALING = 2, - NVTE_NO_SCALING = 3 + /*! Tensor is split into NxN quantization tiles or 1xN quantization tiles, + which each yield a scale. The block_scaling_dim property of the quantizer + selects the granularity. + */ + NVTE_BLOCK_SCALING = 2, + NVTE_INVALID_SCALING = 3, + NVTE_NO_SCALING = 4 }; /*! \brief TE Tensor type @@ -235,6 +240,63 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, const NVTEBasicTensor *param); +/*! \brief Set a quantization option for whether to force power of 2 scales. + * + * \param[in/out] tensor Tensor. + * \param[in] zero_if_false Whether to force power of 2 scales. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect. + */ +int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false); + +/*! \brief Set a quantization option for epsilon to set floor of amax. + * + * \param[in/out] tensor Tensor. + * \param[in] amax_epsilon Epsilon to use for amax calculation. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect. + */ +int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon); + +/*! \brief Set a quantization option to use 1D or 2D quantization blocks + * to scale the tensor. + * + * \param[in/out] tensor Tensor. + * \param[in] block_scaling_dim, 1D or 2D. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect or the number of dims is not supported. + */ +int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim); + +/*! \brief Get a quantization option for whether to force power of 2 scales. + * + * \param[in] tensor Tensor. + * + * \return zero if the tensor will not force power of 2 scales or if the + * setting is irrelevant. non-zero if the flag is configured. + */ +int nvte_get_qopt_force_pow_2_scales(NVTETensor tensor); + +/*! \brief Get a quantization option for amax epsilon. + * + * \param[in] tensor Tensor. + * + * \return amax_epsilon value or zero if not applicable. + */ +float nvte_get_qopt_amax_epsilon(const NVTETensor tensor); + +/*! \brief Get the number of dimensions in the quantization blocks. + * + * \param[in] tensor Tensor. + * + * \return zero if the quantization does not support the block_scaling_dim + * option or the block_scaling_dim configured. + */ +int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor); + /*! \brief Get a value of the parameter of the tensor. * * \param[in] tensor Tensor. @@ -660,6 +722,24 @@ class TensorWrapper { void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } + int set_qopt_force_pow_2_scales(bool flag) { + return nvte_set_qopt_force_pow_2_scales(tensor_, flag ? 1 : 0); + } + + int set_qopt_amax_epsilon(float eps) { return nvte_set_qopt_amax_epsilon(tensor_, eps); } + + int set_qopt_block_scaling_dim(int block_scaling_dim) { + return nvte_set_qopt_block_scaling_dim(tensor_, block_scaling_dim); + } + + bool get_qopt_force_pow_2_scales() const { + return nvte_get_qopt_force_pow_2_scales(tensor_) != 0; + } + + float get_qopt_amax_epsilon() const { return nvte_get_qopt_amax_epsilon(tensor_); } + + int get_qopt_block_scaling_dim() const { return nvte_get_qopt_block_scaling_dim(tensor_); } + static constexpr size_t defaultData = 1; static constexpr NVTEShape defaultShape = {&defaultData, 1}; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1f8bfca2c9..63d28c41cf 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -502,3 +502,48 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { delete reinterpret_cast(config); } } + +int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_force_pow_2_scales_qopt()) { + t.force_pow_2_scales = zero_if_false != 0; + return 0; + } else { + return 1; + } +} + +int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_amax_epsilon_qopt()) { + t.amax_epsilon = amax_epsilon; + return 0; + } else { + return 1; + } +} + +int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_block_scaling_dim(block_scaling_dim)) { + t.block_scaling_dim = block_scaling_dim; + return 0; + } else { + return 1; + } +} + +int nvte_get_qopt_force_pow_2_scales(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.force_pow_2_scales ? 1 : 0; +} + +float nvte_get_qopt_amax_epsilon(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.amax_epsilon; +} + +int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.block_scaling_dim; +} diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index ed9bd5f5f7..cf2cb15174 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -23,6 +23,18 @@ template +#include +#include + +#include +#include + +namespace transformer_engine { + +// Type trait for extreme values of fp8 types. +// Used in the calculation of scale factors +// as a constexpr lookup from e4m3 or e5m2 to +// the max finite value. +template +struct F8LimitsTrait; + +template <> +struct F8LimitsTrait<__nv_fp8_e4m3> { + static constexpr float max = 448.0f; +}; + +template <> +struct F8LimitsTrait<__nv_fp8_e5m2> { + static constexpr float max = 57344.0f; +}; + +// Type trait to resolve the max finite value +// represented by a input type to quantization. +// Or to represent max representable power of 2 +// finite value. +template +struct HighPrecisionFloatScaleLimitsTrait; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + static constexpr float max = std::numeric_limits::max(); +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 127 + static constexpr float max = 0x1.0p127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.(7 bits of 1) * 2 ^ 127 + static constexpr float max = 0x1.FEp127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 127 + static constexpr float max = 0x1.0p127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.(10 bits of 1) * 2 ^ 15 + static constexpr float max = 0x1.FFCp15; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 15 + static constexpr float max = 0x1.0p15; +}; + +// Calculate the quantization scale for an individual data element +// given the amax(abs(tile)) value for a given quantization tile. +// +// +// Arguments: +// IType: data type of the tensor being quantized (float or bf16) +// OType: quantized data type (e4m3 or e5m2) +// pow_2_scaling: Whether to force the scale to be a power of 2. +// amax: The evaluation of amax(abs(tile)) for the quantization tile. +// eps: An epsilon used as a floor for amax. +template +__device__ __forceinline__ float ComputeScale(const float amax, const float eps) { + constexpr float fp8_max = F8LimitsTrait::max; + + // Clamping amax to avoid division by small numbers + float amax_mod = fmaxf(amax, eps); + + // Handle overflow cases for non-clamped amax (eps is 0 or very small) + if (amax_mod == 0.f) { + // If amax is 0, return 1 + return 1.f; + } + // Compute scale factor + float scale = fp8_max / amax_mod; + + if (isinf(scale)) { + // If scale is infinity, return max value of IType + return HighPrecisionFloatScaleLimitsTrait::max; + } + if (scale == 0.0) { + // Case that amax is "inf". The frexp, ldexp logic changes 0.0 scales. + // Return 0.0 for 0.0 scale here is consistent with non-Power2Scaling model. + // quantization will remove signal from the tensor, + // this is bad for the model, but define pow2Scale behavior + // as returning 0.0 scale. amax calculation can + // improve the situation to avoid this by taking largest finite. + return scale; + } + if constexpr (Power2Scaling) { + // NOTE: using bit fiddling based on advice of Asit in this + // thread: https://nvidia.slack.com/archives/C06EDT7LZEW/p1738274404153439 + + // inf scales already early returned, as did nan scales. + // The cases to consider here are normals, zero, and subnormals. + // zero is not possible with current math as + // 448.0 / float_max == 1.31655e-36, which is the smallest + // possible scale given current dtypes. It is still in the normal + // fp32 range with an exponent of -120, so subnormals are also + // not possible. To handle normals, we can simply mask off the + // mantissa. + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. + __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + return scale; +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu new file mode 100644 index 0000000000..934cf8a5fb --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -0,0 +1,603 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/utils.cuh" +#include "compute_scale.cuh" + +#if (!defined(__CUDA_MINIMUM_ARCH__)) || \ + (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) +#define TMA_HW_SUPPORTED +#endif + +namespace transformer_engine { +namespace { + +#ifdef TMA_HW_SUPPORTED +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; +#endif + +// const values configuration + +constexpr size_t kThreadsPerWarp = 32; +#ifdef TMA_HW_SUPPORTED +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 32; +constexpr size_t WARP_TILE_DIM_Y = 64; +constexpr size_t THREAD_TILE_DIM_X = 16; +constexpr size_t THREAD_TILE_DIM_Y = 4; +#else +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 64; +constexpr size_t WARP_TILE_DIM_Y = 32; +constexpr size_t THREAD_TILE_DIM_X = 8; +constexpr size_t THREAD_TILE_DIM_Y = 8; +#endif + +#ifdef TMA_HW_SUPPORTED +constexpr size_t NUM_BYTES_PER_BANK = 4; +constexpr size_t NUM_BANKS_PER_SHARED_ELEM = THREAD_TILE_DIM_Y / NUM_BYTES_PER_BANK; +constexpr size_t SHARED_BLOCK_TILE_DIM_Y = BLOCK_TILE_DIM; +constexpr size_t SHARED_BLOCK_TILE_DIM_X_BANKS = + BLOCK_TILE_DIM / (NUM_BYTES_PER_BANK * NUM_BANKS_PER_SHARED_ELEM); +constexpr size_t NUM_BANKS_Y_IN_WARP = WARP_TILE_DIM_Y / NUM_BYTES_PER_BANK; +#endif +constexpr size_t ELE_PER_THREAD = THREAD_TILE_DIM_X * THREAD_TILE_DIM_Y; +constexpr size_t THREADS_PER_BLOCK = BLOCK_TILE_DIM * BLOCK_TILE_DIM / ELE_PER_THREAD; +constexpr size_t NUM_WARPS_X_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_X; +constexpr size_t NUM_WARPS_Y_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_Y; +constexpr size_t NUM_WARPS_IN_BLOCK = NUM_WARPS_X_IN_BLOCK * NUM_WARPS_Y_IN_BLOCK; + +constexpr size_t NUM_THREADS_X_IN_WARP = WARP_TILE_DIM_X / THREAD_TILE_DIM_X; +constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP; + +#define MIN(a, b) (a < b ? a : b) + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon, + const __grid_constant__ CUtensorMap tensor_map_output_t) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + // This is ONLY true if the input is a full tile + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_idx = + tile_id_y * BLOCK_TILE_DIM * row_length + tile_id_x * BLOCK_TILE_DIM; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + +// Step 1: Load a block tile of input data into thread tiles on registers +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from(input + thread_tile_start_idx + i * row_length); + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + OVecCast tmp_output_c; + + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length); + } + + // Step 4: store transpose into shared memory + if constexpr (kReturnTranspose) { +#ifdef TMA_HW_SUPPORTED + __shared__ alignas(128) + OVecTrans block_tile_trans_shared[SHARED_BLOCK_TILE_DIM_Y][SHARED_BLOCK_TILE_DIM_X_BANKS]; + OType(*block_tile_trans_shared_otype_ptr)[BLOCK_TILE_DIM] = + reinterpret_cast(block_tile_trans_shared); + +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + auto warp_id_in_block_x_ = warp_id_in_block_y; + auto warp_id_in_block_y_ = warp_id_in_block_x; + int row_idx = warp_id_in_block_y_ * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X + i; + int col_idx = + warp_id_in_block_x_ * (NUM_BANKS_Y_IN_WARP / NUM_BANKS_PER_SHARED_ELEM) + tid_in_warp_y; + block_tile_trans_shared[row_idx][col_idx] = thrd_tile_out_trans[i]; + } + + // Wait for shared memory writes to be visible to TMA engine. + cde::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Step 5: store transpose output + // Initiate TMA transfer to copy shared memory to global memory + if (threadIdx.x == 0) { + cde::cp_async_bulk_tensor_2d_shared_to_global( + &tensor_map_output_t, tile_id_y * BLOCK_TILE_DIM, tile_id_x * BLOCK_TILE_DIM, + block_tile_trans_shared_otype_ptr); + // Wait for TMA transfer to have finished reading shared memory. + // Create a "bulk async-group" out of the previous bulk copy operation. + cde::cp_async_bulk_commit_group(); + // Wait for the group to have completed reading from shared memory. + cde::cp_async_bulk_wait_group_read<0>(); + } +#else + // Step 4 Alternative (when TMA is not available, skip writing to shared memory) + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + thrd_tile_out_trans[i].store_to(output_t + thread_tile_t_start_idx + i * num_rows); + } +#endif + } +} + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( + const IType* const input, OType* const output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_row_idx = tile_id_y * BLOCK_TILE_DIM; + const size_t block_tile_start_col_idx = tile_id_x * BLOCK_TILE_DIM; + const size_t block_tile_start_idx = + block_tile_start_row_idx * row_length + block_tile_start_col_idx; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + // handle non-full tile + // check for three cases: full thread tile, nonfull thread tile, empty thread tile + // for empty thread tile, directly write zero to the transposed shared mem buffer + // for nonfull thread tile, fill zero to thread tile and act as if it's full + const size_t thread_tile_start_row_idx = + tile_id_y * BLOCK_TILE_DIM + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP + + tid_in_warp_y * THREAD_TILE_DIM_Y; + const size_t thread_tile_start_col_idx = + tile_id_x * BLOCK_TILE_DIM + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X; + + const size_t thread_tile_end_row_idx = thread_tile_start_row_idx + THREAD_TILE_DIM_Y - 1; + const size_t thread_tile_end_col_idx = thread_tile_start_col_idx + THREAD_TILE_DIM_X - 1; + + bool full_thrd_tile = + (thread_tile_end_row_idx < num_rows) && (thread_tile_end_col_idx < row_length); + bool empty_thrd_tile = + (thread_tile_start_row_idx >= num_rows) || (thread_tile_start_col_idx >= row_length); + bool nonfull_thrd_tile = (!full_thrd_tile) && (!empty_thrd_tile); + + const size_t thread_tile_ncols = + MIN(THREAD_TILE_DIM_X, + (MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1)); + const size_t thread_tile_nrows = + MIN(THREAD_TILE_DIM_Y, + (MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1)); + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + + if (!empty_thrd_tile) { + // Step 1: Load a block tile of input data into thread tiles on registers + // Edge case: nonfull thread tile case, will use the partial load function here + if (nonfull_thrd_tile) { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + thrd_tile_input[i].clear(); + } else { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + } + } else { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + THREAD_TILE_DIM_X); + } + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + // Edge case: in the non-full tile case, there are three subcases + // for full thread tile, it's the same thing here + // for nonfull thread tile, pay attention when saving tmp_output_c to global + // memory, cannot vec store_to, but need to elt store to for empty tile, + // it should not enter this step, skip to Step 4 + + // set thrd_tile_out_trans to all zero + if constexpr (kReturnTranspose) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + thrd_tile_out_trans[j].clear(); + } + } + + if (!empty_thrd_tile) { + OVecCast tmp_output_c; + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + continue; + } +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + + if constexpr (kReturnTranspose) { + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < thread_tile_ncols; i++) { + thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0, + thread_tile_nrows); + } + } + } +} + +PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { + void* driver_ptr = nullptr; + cudaDriverEntryPointQueryResult driver_status; + NVTE_CHECK_CUDA(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, + &driver_status)); + return reinterpret_cast(driver_ptr); +} + +template +CUtensorMap get_tensor_map(SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { + // example-begin create-tensor-map + CUtensorMap tensor_map_output_trans{}; + // rank is the number of dimensions of the array. + constexpr uint32_t rank = 2; + uint64_t size[rank] = {global_dim_x, global_dim_y}; // x, y + // The stride is the number of bytes to traverse from the first element of one row to the next. + // It must be a multiple of 16. + uint64_t stride[rank - 1] = {global_dim_x * sizeof(OutputType)}; + // The box_size is the size of the shared memory buffer that is used as the + // destination of a TMA transfer. + uint32_t box_size[rank] = {BLOCK_TILE_DIM, BLOCK_TILE_DIM}; + // The distance between elements in units of sizeof(element). A stride of 2 + // can be used to load only the real component of a complex-valued tensor, for instance. + uint32_t elem_stride[rank] = {1, 1}; + + // Get a function pointer to the cuTensorMapEncodeTiled driver API. + auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); + CUtensorMapDataType dataType; + + if constexpr (std::is_same_v || + std::is_same_v) { + dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else { + NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + } + + // Create the tensor descriptor. + CUresult res = cuTensorMapEncodeTiled( + &tensor_map_output_trans, // CUtensorMap *tensorMap, + dataType, + rank, // cuuint32_t tensorRank, + reinterpret_cast(tensor.dptr), // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + box_size, // const cuuint32_t *boxDim, + elem_stride, // const cuuint32_t *elementStrides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + // Swizzling can be used to avoid shared memory bank conflicts. + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + // Any element that is outside of bounds will be set to zero by the TMA transfer. + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + return tensor_map_output_trans; +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void nvte_quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow_2_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_transpose_square_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + } + + NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions."); + + size_t scale_k = scale_inv.shape[1]; + + const size_t scale_stride_x = 1; + const size_t scale_stride_y = scale_k; + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type."); + + NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions."); + + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); + const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + pow_2_scale, kPow2Scale, + + if (full_tile) { + CUtensorMap tensor_map_output_trans; + if constexpr (kReturnTranspose) { + tensor_map_output_trans = + get_tensor_map(output_t, num_rows, row_length); + } + block_scaled_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon, tensor_map_output_trans); + } else { + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon); + } // full-tile + + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu new file mode 100644 index 0000000000..d9676504ed --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -0,0 +1,528 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/utils.cuh" +#include "compute_scale.cuh" + +namespace transformer_engine { +namespace { + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr size_t kThreadsPerWarp = 32; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; // Fixed to 128 beacause we are using 1x128 and 128x1 quantization +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; +constexpr int kNumThreadsStore = kTileDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, + const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon) { + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory + const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory + const size_t num_ele = + c_g < row_length ? min((size_t)kNVecIn, row_length - c_g) : 0; // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + // Step 2: Cast and store to output_c + { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory + const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory + const size_t num_ele = + c_g < row_length ? min((size_t)kNVecOut, row_length - c_g) : 0; // For not aligned case + OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + // Step 2.4: Compute scale + CType scale = ComputeScale(amax, epsilon); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = (size_t)blockIdx.y * kTileDim + r_s; + size_t col_idx = (size_t)blockIdx.x; + if constexpr (kPermuteScale) { + size_t p_row = row_idx / kTileDim; + size_t p_col = col_idx; + size_t p_dep = row_idx % kTileDim; + size_t p_2d_stride = kTileDim * scale_stride_y; + tile_scales_inv_c[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; + } else { + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + output_vec.data.elt[i * kNVecSMem + j] = + static_cast(static_cast(smem_vec[i].data.elt[j]) * scale); + } + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if constexpr (kReturnTranspose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = (size_t)blockIdx.y * kTileDim + r_s; // Column in global memory + const size_t stride_g = (size_t)c_stride * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = + c_g < num_rows ? min((size_t)kNVecOut, num_rows - c_g) : 0; // For not aligned case + OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + // Step 3.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + // Step 3.4: Compute scale + CType scale = ComputeScale(amax, epsilon); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = (size_t)blockIdx.y; + if constexpr (kPermuteScale) { + size_t p_row = row_idx / kTileDim; + size_t p_col = col_idx; + size_t p_dep = row_idx % kTileDim; + size_t p_2d_stride = kTileDim * scale_t_stride_y; + tile_scales_inv_t[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; + } else { + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + output_vec.data.elt[i] = + static_cast(static_cast(smem_vec[i].data.elt[smem_idx]) * scale); + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void nvte_quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow2_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_transpose_vector_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + // Options for scale layout of cuBLAS GEMM kernel. + constexpr bool kPermuteScale = false; + bool permute_scale = false; + bool transpose_scales = true; + + NVTE_CHECK(input.shape.size() == output.shape.size(), + "Input and output must have the same shape."); + NVTE_CHECK((!transpose_scales || !permute_scale), + "Permute scale and transpose scales are mutually exclusive flags."); + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + if (permute_scale) { + NVTE_CHECK(scale_inv.shape.size() == 3, "scale_inv must have 3 dimensions."); + size_t scale_k = scale_inv.shape[1]; + NVTE_CHECK(scale_inv.shape[2] == kTileDim, "Scale inner dimension must be kTileDim."); + scale_stride_x = 1; + scale_stride_y = scale_k; + } else { + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2 when not permuting scale."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = 1; + scale_stride_y = scale_k; + if (transpose_scales) { + std::swap(scale_stride_x, scale_stride_y); + } + } + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); + + if (permute_scale) { + NVTE_CHECK(scale_inv_t.shape.size() == 3, "Scale_t dimension must be 3."); + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + NVTE_CHECK(scale_inv_t.shape[2] == kTileDim, "Scale_t inner dimension must be kTileDim."); + } else { + NVTE_CHECK(scale_inv_t.shape.size() == 2, + "Scale_t dimension must be 2 when not permuting scale."); + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + if (transpose_scales) { + std::swap(scale_t_stride_x, scale_t_stride_y); + } + } + } + + const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); + const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + pow2_scale, kPow2Scale, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, kAligned, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + // shared memory must be requested up + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + &block_scaled_1d_cast_transpose_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); + } block_scaled_1d_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon);) // kAligned + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index ba2890ada3..c342805b1a 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1262,6 +1262,27 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe workspace_tensor, stream); break; } + case NVTE_BLOCK_SCALING: { + // FIXME(kwyss): Currently ignoring IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters. + if (output_tensor->block_scaling_dim == 2) { + nvte_quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); + } else if (output_tensor->block_scaling_dim == 1) { + nvte_quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); + } else { + NVTE_ERROR("Not supported block scaling dim."); + } + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e529289640..6e798ca748 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -349,6 +349,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } } else { + // FIXME(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 3d807960ca..d1470e22e3 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -24,6 +24,12 @@ torch.bfloat16: tex.DType.kBFloat16, } +""" +This is a map: int -> torch.dtype +Used for resolving cuda extension types to torch. +Has one to one mapping with enum in +transformer_engine.h +""" TE_DType_To_Torch = { tex.DType.kByte: torch.uint8, tex.DType.kFloat8E4M3: torch.float8_e4m3fn, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2cf47e7399..1fae53b791 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -158,6 +158,36 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class Float8BlockQuantizer : public Quantizer { + public: + // Which float8 type is used for q data. + DType dtype; + + private: + // Options about how to quantize the tensor + // Quantization scales are rounded down to powers of 2. + bool force_pow_2_scales = false; + // Amax within quantization tile has a floor of epsilon. + float amax_epsilon = 0.0; + int block_scaling_dim = 2; + + public: + // Initializes from a python handle to a Float8BlockQuantizer + explicit Float8BlockQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_BLOCK_SCALING; } + + // Gets rowwise and columnwise_data from tensor and sets them on wrapper + void set_quantization_params(TensorWrapper* tensor) const override; + + // Create a python Float8BlockQuantized tensor and C++ wrapper + // for the tensor. Should set quantized data, scales for rowwise + // and optionally columnwise usage. std::pair create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data = std::nullopt) const override; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 097cf63acc..f5ccb2d29b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorBasePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; +PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -61,9 +64,31 @@ void init_mxfp8_extension() { "Internal error: could not initialize pyTorch MXFP8 extension."); } +void init_float8blockwise_extension() { + if (Float8BlockwiseQTensorBasePythonClass) return; + auto fp8_module = + py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); + auto fp8_base_module = py::module_::import( + "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base"); + Float8BlockwiseQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); + Float8BlockwiseQTensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase")); + Float8BlockwiseQTensorPythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); + + NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); + init_float8blockwise_extension(); } } // namespace transformer_engine::pytorch @@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 5121bc7f88..9ebb9d4f86 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -250,6 +250,134 @@ std::pair Float8CurrentScalingQuantizer::create_tenso tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); } this->set_quantization_params(&tensor); + + return {std::move(tensor), std::move(ret)}; +} + +Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); + this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); +} + +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // Change the rowwise and columnwise_data to the configured dtype. + // May be a switch between E5M2 and E4M3. + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); + + // Set options on TensorWrapper from quantization. + tensor->set_qopt_force_pow_2_scales(force_pow_2_scales); + tensor->set_qopt_amax_epsilon(amax_epsilon); + tensor->set_qopt_block_scaling_dim(block_scaling_dim); +} + + +std::pair Float8BlockQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(NVTE_BLOCK_SCALING); + at::TensorOptions opts; + at::TensorOptions scale_opts; + at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + + size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back(); + size_t m_dim = numel / k_dim; + constexpr size_t kBlockLen = 128; + + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data_rowwise = std::move(*rowwise_data); + } else { + data_rowwise = at::empty(torch_shape, opts); + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = m_dim; + } else { + NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor rowwise."); + } + scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts); + tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + + if (columnwise_usage) { + std::vector torch_columnwise_shape; + std::vector columnwise_shape; + NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape."); + if (torch_shape.size() > 0) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = k_dim; + } else { + NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor columnwise."); + } + data_colwise = at::empty(torch_columnwise_shape, opts); + scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts); + + tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); + tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + ret = Float8BlockwiseQTensorClass( + "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, + "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorPythonClass)); + ret = Float8BlockwiseQTensorClass( + "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, + "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, + "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } + return {std::move(tensor), std::move(ret)}; } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index d5654fb43a..18a08605b6 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -84,6 +84,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(NVTE_BLOCK_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); + + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + quantizer->set_quantization_params(&ret); + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index b0f55d7598..6cd62cd1d0 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -25,6 +25,9 @@ extern PyTypeObject *Float8CurrentScalingQuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; extern PyTypeObject *MXFP8TensorBasePythonClass; extern PyTypeObject *MXFP8QuantizerClass; +extern PyTypeObject *Float8BlockwiseQTensorPythonClass; +extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; +extern PyTypeObject *Float8BlockwiseQuantizerClass; void init_extension(); @@ -50,6 +53,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; } +inline bool IsFloat8BlockwiseQParams(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; +} + +inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || + Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; +} + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -61,6 +73,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati std::unique_ptr CreateMXFP8Params(const py::handle params); +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, + Quantizer *quantization_params); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } @@ -70,8 +85,10 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), - std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, - CreateQuantizer)}; + std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, + CreateQuantizer), + std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQParams, + NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; } // namespace detail diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py new file mode 100644 index 0000000000..ffed102ee5 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -0,0 +1,246 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8BlockwiseQTensor""" + +from __future__ import annotations +import math +from typing import Optional, Dict, Any, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType_To_Torch + +from ..quantized_tensor import Quantizer + + +class Float8BlockwiseQTensorBase: + """Mixin class that holds data attributes of Float8BlockwiseQTensor. + + Float8BlockwiseQTensor inherits from the PyTorch tensor class and this + mixin class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Quantizer + _fp8_dtype: TE_DType + _rowwise_scale_inv: Optional[torch.Tensor] + _columnwise_scale_inv: Optional[torch.Tensor] + + def __new__( + cls, + *args, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: TE_DType, + quantizer: Quantizer, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: + """Prepare the tensor base for saving for backward + + FIXME(kwyss): Should this clear out data? + FIXME(kwyss): What about dq scales? + """ + tensors = [self._rowwise_data, self._columnwise_data] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: + """Takes dequantized columnwise data and permutes to a rowwise shape""" + if columnwise_dq.dim() < 2: + return columnwise_dq + permute_dims = [x for x in range(1, columnwise_dq.dim())] + permute_dims.append(0) + return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() + + def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + block_len = 128 + + q_M, q_K = 1, 1 + if self._rowwise_data is not None: + q = self._rowwise_data + scale_inv = self._rowwise_scale_inv + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + else: + assert self._columnwise_data is not None, "No data to dequantize" + q = self._columnwise_data + scale_inv = self._columnwise_scale_inv + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + + orig_shape = q.shape + q = q.reshape(q_M, q_K) + k_tiles, m = scale_inv.shape + if q_K % block_len != 0: + k_pad_amount = (block_len - (q_K % block_len)) % block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, 0), mode="constant", value=0 + ).contiguous() + _, padded_K = q.shape + q_tiled = q.reshape(q_M, k_tiles, block_len) + dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(m, k_tiles, 1) + torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] + result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale + if padded_K != q_K: + result = result.reshape(q_M, padded_K)[:, :q_K] + result = result.to(dtype) + if len(orig_shape) == 0: + result = result.reshape([]) + else: + result = result.reshape(*orig_shape).contiguous() + + if transpose_output: + return self._transpose_dq_columnwise_output(result) + return result + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + """ + block_len = 128 + assert self._quantizer is not None + if self._quantizer.block_scaling_dim != 2: + assert self._quantizer.block_scaling_dim == 1 + return self._dequantize_vectorwise(dtype=dtype) + + def format_scale_as_logical_shape(q_M, q_K, scales, block_len): + # The GEMM for 2D blocks required padding in the scales. + derived_scale_k_shape = math.ceil(q_K / block_len) + scale_M, scale_K = scales.shape + if derived_scale_k_shape == scale_K: + return scales + else: + return scales[:, :derived_scale_k_shape].contiguous() + return formatted_scales + + q_M, q_K = 1, 1 + if self._rowwise_data is not None: + q = self._rowwise_data + scale_inv = self._rowwise_scale_inv + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + else: + assert self._columnwise_data is not None, "No data to dequantize" + q = self._columnwise_data + scale_inv = self._columnwise_scale_inv + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + + orig_shape = q.shape + q = q.reshape(q_M, q_K) + formatted_scales = format_scale_as_logical_shape(q_M, q_K, scale_inv, block_len) + assert len(formatted_scales.shape) == 2 + m_tiles, k_tiles = formatted_scales.shape + unpadded_m, unpadded_k = q_M, q_K + m_block_len = block_len + k_block_len = block_len + if q_M % m_block_len != 0 or q_K % k_block_len != 0: + m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len + k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + padded_M, padded_K = q.shape + q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len) + + torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] + + result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view( + m_tiles, 1, k_tiles, 1 + ) + result = result.view(padded_M, padded_K).to(dtype) + if padded_M != unpadded_m or padded_K != unpadded_k: + result = result[:unpadded_m, :unpadded_k] + if len(orig_shape) == 0: + result = result.reshape([]) + else: + result = result.reshape(*orig_shape).contiguous() + if transpose_output: + return self._transpose_dq_columnwise_output(result) + return result + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + if self._rowwise_data is not None: + return self._rowwise_data.size(*args, **kwargs) + else: + dims = list(self._columnwise_data.size(*args, **kwargs)) + reordered = [] + for i in range(1, len(dims)): + reordered.append(dims[i]) + reordered.append(dims[0]) + return torch.Size(reordered) + + def __repr__(self): + if self._rowwise_data is not None: + data = self.dequantize() + descriptor = "rowwise" + scale_inv = self._rowwise_scale_inv + else: + data = self.dequantize() + descriptor = "columnwise" + scale_inv = self._columnwise_scale_inv + return ( + "Float8BlockwiseQTensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"{descriptor}_scaled_data={data_rowwise}" + f"{descriptor}_scale_inv={scale_inv}, " + ")" + ) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py new file mode 100644 index 0000000000..8090ac20b8 --- /dev/null +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -0,0 +1,539 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data quantized with NxN tiles""" +from __future__ import annotations +from typing import Optional, Tuple, Iterable +import warnings + +import math +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..utils import devices_match, round_up_to_nearest_multiple + +aten = torch.ops.aten + + +class Float8BlockQuantizer(Quantizer): + """Builder class for tensors quantized with current scaling using + NxN quantization tilings to choose scale. + + This class is typically used to convert a high-precision tensor + (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + dtype: TE_DType + block_len: int + amax_epsilon: float + force_pow_2_scales: bool + block_scaling_dim: int + + def __init__( + self, + fp8_dtype: TE_DType, + *, + rowwise: bool, + columnwise: bool, + amax_epsilon: float = 0.0, + force_pow_2_scales: bool = False, + block_scaling_dim: int = 2, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + assert rowwise + self.dtype = fp8_dtype + self.block_len = 128 + self.force_pow_2_scales = force_pow_2_scales + self.amax_epsilon = amax_epsilon + self.block_scaling_dim = block_scaling_dim + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + assert isinstance( + dst, Float8BlockwiseQTensor + ), f"Cannot store quantized blockwise tensor in {type(dst)} type." + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + dst._fp8_dtype = self.dtype + return dst + + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: + # cuBLAS kernel format (for NxN by NxN and 1xN by NxN GEMMs) + # The scales for 2D block quantized tensors must have scales padded + # to multiples of 4 on the inner dimension. TODO: Verify whether outer + # dimension also to be padded for either GEMM. + if self.block_scaling_dim == 2: + logical_scale_shape = [1, 1] + for i in range(len(shape) - 1): + logical_scale_shape[-2] *= shape[i] + if len(shape) > 0: + logical_scale_shape[-1] = math.ceil(shape[-1] / self.block_len) + logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) + if columnwise: + tmp = logical_scale_shape[-1] + logical_scale_shape[-1] = logical_scale_shape[-2] + logical_scale_shape[-2] = tmp + logical_scale_shape[-1] = round_up_to_nearest_multiple(logical_scale_shape[-1], 4) + return tuple(logical_scale_shape) + else: + assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" + + logical_scale_shape = [1, 1] + for i in range(len(shape) - 1): + logical_scale_shape[-1] *= shape[i] + if len(shape) > 0: + logical_scale_shape[-2] = shape[-1] + if not columnwise: + logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) + return tuple(logical_scale_shape) + else: + logical_scale_shape[-1] = math.ceil(logical_scale_shape[-1] / self.block_len) + return (logical_scale_shape[1], logical_scale_shape[0]) + + def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: + if len(shape) == 0: + return tuple() + colwise_shape = [shape[-1]] + for i in range(len(shape) - 1): + colwise_shape.append(shape[i]) + return tuple(colwise_shape) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> Float8BlockwiseQTensor: + """Construct quantized tensor with uninitialized data""" + if device is None: + device = torch.device("cuda") + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty( + self.get_columnwise_shape(shape), dtype=torch.uint8, device=device + ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, + dtype=torch.float32, + device=device, + ) + + # Construct FP8 tensor + return Float8BlockwiseQTensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # NOTE: This interface is specific to requirements like delayed scaling + # where state from an estimator influences distribution parameters. + pass + + +class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): + """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + rowwise_data: torch.Tensor + FP8 data in a uint8 tensor matching shape of dequantized tensor. + rowwise_scale_inv: torch.Tensor + FP32 dequantization scales in GEMM format for dequantizing rowwise_data. + columnwise_data: Optional[torch.Tensor] + FP8 data in a uint8 tensor matching shape of dequantized tensor transpose. + columnwise_scale_inv: Optional[torch.Tensor] + FP32 dequantization scales in GEMM format for dequantizing columnwise_data. + + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and + holds configuration about quantization and dequantization modes. + """ + + def __repr__(self, *, tensor_contents=None): + return ( + f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," + f" data={self.dequantize(dtype=self.dtype)})" + ) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + assert self._quantizer is not None + return self._quantizer + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> Float8BlockwiseQTensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + + By default the resulting tensor's dtype is the + Float8BlockwiseQTensor's pre-quantized dtype. + """ + if dtype is not None: + dequant_dtype = dtype + else: + dequant_dtype = self.dtype + return super().dequantize(dtype=dequant_dtype) + + def detach(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return Float8BlockwiseQTensor.make_like(self) + + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + """ + update_usage can be used to clear out one of two possible copies of the data. + """ + + assert ( + columnwise_usage or rowwise_usage + ), "Must retain some data either columnwise or rowwise" + + if columnwise_usage and rowwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage." + return + + if rowwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise usage." + self._columnwise_data = None + self._columnwise_scale_inv = None + return + if columnwise_usage: + assert ( + self._columnwise_data is not None and self._columnwise_scale_inv is not None + ), "Cannot update to columnwise usage." + self._rowwise_data = None + self._rowwise_scale_inv = None + return + + return + + def clone(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + rowwise_data = None + if self._rowwise_data is not None: + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8BlockwiseQTensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if ( + self._rowwise_data is not None + and self._rowwise_data.is_contiguous(memory_format=memory_format) + and ( + (self._columnwise_data is None) + or (self._columnwise_data.is_contiguous(memory_format=memory_format)) + ) + ): + return self + raise ValueError("Float8BlockwiseQTensor does not support different memory formats!") + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + return Float8BlockwiseQTensor( + shape=out_shape, + dtype=tensor.dtype, + rowwise_data=out_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + ) -> Float8BlockwiseQTensor: + """Build Float8BlockwiseQTensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return Float8BlockwiseQTensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8BlockwiseQTensor._make_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + ), + ) + + def _get_data(self) -> Float8BlockwiseQTensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a Float8BlockwiseQTensor. Otherwise + casts to FP8. + + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + + # Just copy FP8 data if other tensor is Float8BlockwiseQTensor + if isinstance(tensor, Float8BlockwiseQTensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + Float8BlockwiseQTensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(Float8BlockwiseQTensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting Float8BlockwiseQTensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8BlockwiseQTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8BlockwiseQTensor, + shape: Optional[list[int]] = None, + ) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + if shape != ctx.shape: + raise NotImplementedError("View not implemented.") + else: + return tensor + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, Float8BlockwiseQTensor): + raise NotImplementedError("View bwd not implemented") + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8BlockwiseQTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8BlockwiseQTensor, + shape: Optional[list[int]] = None, + ) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape != ctx.shape: + raise NotImplementedError("Reshape not implemented yet.") + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, Float8BlockwiseQTensor): + raise NotImplementedError("Reshape bwd not implemented yet.") + return grad.view(ctx.shape), None From bf9c137d83904686a6c63fa65d3bc0530e9fb486 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 6 Mar 2025 09:25:26 -0800 Subject: [PATCH 02/38] Apply linting changes. Signed-off-by: Keith Wyss --- .../common/transpose/compute_scale.cuh | 6 + .../quantize_transpose_vector_blockwise.cu | 54 ++++---- .../_internal/float8_blockwise_tensor_base.py | 40 +++--- .../pytorch/tensor/float8_blockwise_tensor.py | 120 +++++++++++++----- 4 files changed, 140 insertions(+), 80 deletions(-) diff --git a/transformer_engine/common/transpose/compute_scale.cuh b/transformer_engine/common/transpose/compute_scale.cuh index 7ef94d74fe..82ce6c5df7 100644 --- a/transformer_engine/common/transpose/compute_scale.cuh +++ b/transformer_engine/common/transpose/compute_scale.cuh @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + #ifndef TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ #define TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index d9676504ed..bacb757822 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -162,13 +162,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; const int c_s = - (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory - int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory - const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory - size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory - const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory - const size_t num_ele = - c_g < row_length ? min((size_t)kNVecIn, row_length - c_g) : 0; // For not aligned case + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = + static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = c_g < row_length ? min(static_cast(kNVecIn), row_length - c_g) + : 0; // For not aligned case const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory #pragma unroll for (int iter = 0; iter < num_iterations; ++iter) { @@ -207,13 +208,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; const int c_s = - (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory - int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory - const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory - size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory - const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory - const size_t num_ele = - c_g < row_length ? min((size_t)kNVecOut, row_length - c_g) : 0; // For not aligned case + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = + static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = c_g < row_length ? min(static_cast(kNVecOut), row_length - c_g) + : 0; // For not aligned case OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of // the first thread to do the reduction. @@ -259,8 +261,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } if (write_scale_inv) { CType scale_inv = 1.0 / scale; - size_t row_idx = (size_t)blockIdx.y * kTileDim + r_s; - size_t col_idx = (size_t)blockIdx.x; + size_t row_idx = static_cast(blockIdx.y) * kTileDim + r_s; + size_t col_idx = static_cast(blockIdx.x); if constexpr (kPermuteScale) { size_t p_row = row_idx / kTileDim; size_t p_col = col_idx; @@ -303,13 +305,15 @@ __global__ void __launch_bounds__(kThreadsPerBlock) constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); - const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory - int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory - size_t r_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Row in global memory - const size_t c_g = (size_t)blockIdx.y * kTileDim + r_s; // Column in global memory - const size_t stride_g = (size_t)c_stride * kNVecSMem * num_rows; // Stride in global memory - const size_t num_ele = - c_g < num_rows ? min((size_t)kNVecOut, num_rows - c_g) : 0; // For not aligned case + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = + static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = static_cast(blockIdx.y) * kTileDim + r_s; // Column in global memory + const size_t stride_g = + static_cast(c_stride) * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = c_g < num_rows ? min(static_cast(kNVecOut), num_rows - c_g) + : 0; // For not aligned case OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of // the first thread to do the reduction. @@ -353,8 +357,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } if (write_scale_inv) { CType scale_inv = 1.0 / scale; - size_t row_idx = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem + smem_idx; - size_t col_idx = (size_t)blockIdx.y; + size_t row_idx = static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = static_cast(blockIdx.y); if constexpr (kPermuteScale) { size_t p_row = row_idx / kTileDim; size_t p_col = col_idx; diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index ffed102ee5..42c236dbf5 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -9,10 +9,8 @@ from typing import Optional, Dict, Any, Tuple import torch -import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...constants import TE_DType_To_Torch from ..quantized_tensor import Quantizer @@ -93,7 +91,7 @@ def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch. """Takes dequantized columnwise data and permutes to a rowwise shape""" if columnwise_dq.dim() < 2: return columnwise_dq - permute_dims = [x for x in range(1, columnwise_dq.dim())] + permute_dims = list(range(1, columnwise_dq.dim())) permute_dims.append(0) return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() @@ -121,7 +119,7 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch orig_shape = q.shape q = q.reshape(q_M, q_K) - k_tiles, m = scale_inv.shape + k_tiles, scale_m = scale_inv.shape if q_K % block_len != 0: k_pad_amount = (block_len - (q_K % block_len)) % block_len q = torch.nn.functional.pad( @@ -129,7 +127,10 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch ).contiguous() _, padded_K = q.shape q_tiled = q.reshape(q_M, k_tiles, block_len) - dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(m, k_tiles, 1) + if scale_m > q_M: + # scale_m is 4 element aligned. + scale_inv = scale_inv[:, :q_M].contiguous() + dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1) torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale if padded_K != q_K: @@ -154,15 +155,13 @@ def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: assert self._quantizer.block_scaling_dim == 1 return self._dequantize_vectorwise(dtype=dtype) - def format_scale_as_logical_shape(q_M, q_K, scales, block_len): + def format_scale_as_logical_shape(q_K, scales, block_len): # The GEMM for 2D blocks required padding in the scales. derived_scale_k_shape = math.ceil(q_K / block_len) - scale_M, scale_K = scales.shape + _, scale_K = scales.shape if derived_scale_k_shape == scale_K: return scales - else: - return scales[:, :derived_scale_k_shape].contiguous() - return formatted_scales + return scales[:, :derived_scale_k_shape].contiguous() q_M, q_K = 1, 1 if self._rowwise_data is not None: @@ -185,7 +184,7 @@ def format_scale_as_logical_shape(q_M, q_K, scales, block_len): orig_shape = q.shape q = q.reshape(q_M, q_K) - formatted_scales = format_scale_as_logical_shape(q_M, q_K, scale_inv, block_len) + formatted_scales = format_scale_as_logical_shape(q_K, scale_inv, block_len) assert len(formatted_scales.shape) == 2 m_tiles, k_tiles = formatted_scales.shape unpadded_m, unpadded_k = q_M, q_K @@ -220,27 +219,22 @@ def size(self, *args, **kwargs): # pylint: disable=missing-function-docstring if self._rowwise_data is not None: return self._rowwise_data.size(*args, **kwargs) - else: - dims = list(self._columnwise_data.size(*args, **kwargs)) - reordered = [] - for i in range(1, len(dims)): - reordered.append(dims[i]) - reordered.append(dims[0]) - return torch.Size(reordered) + dims = list(self._columnwise_data.size(*args, **kwargs)) + reordered = [] + for i in range(1, len(dims)): + reordered.append(dims[i]) + reordered.append(dims[0]) + return torch.Size(reordered) def __repr__(self): if self._rowwise_data is not None: data = self.dequantize() descriptor = "rowwise" - scale_inv = self._rowwise_scale_inv else: data = self.dequantize() descriptor = "columnwise" - scale_inv = self._columnwise_scale_inv return ( "Float8BlockwiseQTensorBase(" f"fp8_dtype={self._fp8_dtype}, " - f"{descriptor}_scaled_data={data_rowwise}" - f"{descriptor}_scale_inv={scale_inv}, " - ")" + f"{descriptor}_scaled_data={data}" ) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 8090ac20b8..4b4dc3b94e 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -5,7 +5,6 @@ """Tensor class with FP8 data quantized with NxN tiles""" from __future__ import annotations from typing import Optional, Tuple, Iterable -import warnings import math import torch @@ -59,6 +58,29 @@ def update_quantized( *, noop_flag: Optional[torch.Tensor] = None, ) -> QuantizedTensor: + """Update the quantized tensor with data from the source tensor. + + This method quantizes the input tensor and stores the result in the destination tensor. + + Parameters + ---------- + src : torch.Tensor + Source tensor containing the data to be quantized + dst : QuantizedTensor + Destination tensor where the quantized data will be stored + noop_flag : Optional[torch.Tensor] + Optional flag tensor indicating whether to skip the quantization operation + + Returns + ------- + QuantizedTensor + The destination tensor containing the quantized data + + Raises + ------ + AssertionError + If the destination tensor is not a Float8BlockwiseQTensor + """ assert isinstance( dst, Float8BlockwiseQTensor ), f"Cannot store quantized blockwise tensor in {type(dst)} type." @@ -75,39 +97,70 @@ def update_quantized( return dst def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: - # cuBLAS kernel format (for NxN by NxN and 1xN by NxN GEMMs) - # The scales for 2D block quantized tensors must have scales padded - # to multiples of 4 on the inner dimension. TODO: Verify whether outer - # dimension also to be padded for either GEMM. + """Calculate the shape of the scaling tensor for blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For 2D tensors: + - If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4)) + - If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4)) + For 1D tensors: + - If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4)) + - If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4)) + """ + M, K = 1, 1 + for i in range(len(shape) - 1): + M *= shape[i] + if len(shape) > 0: + K = shape[-1] if self.block_scaling_dim == 2: - logical_scale_shape = [1, 1] - for i in range(len(shape) - 1): - logical_scale_shape[-2] *= shape[i] - if len(shape) > 0: - logical_scale_shape[-1] = math.ceil(shape[-1] / self.block_len) - logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) if columnwise: - tmp = logical_scale_shape[-1] - logical_scale_shape[-1] = logical_scale_shape[-2] - logical_scale_shape[-2] = tmp - logical_scale_shape[-1] = round_up_to_nearest_multiple(logical_scale_shape[-1], 4) - return tuple(logical_scale_shape) - else: - assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" - - logical_scale_shape = [1, 1] - for i in range(len(shape) - 1): - logical_scale_shape[-1] *= shape[i] - if len(shape) > 0: - logical_scale_shape[-2] = shape[-1] - if not columnwise: - logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) - return tuple(logical_scale_shape) - else: - logical_scale_shape[-1] = math.ceil(logical_scale_shape[-1] / self.block_len) - return (logical_scale_shape[1], logical_scale_shape[0]) + outer = math.ceil(K / self.block_len) + inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) + return (outer, inner) + outer = math.ceil(M / self.block_len) + inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) + return (outer, inner) + assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" + if columnwise: + outer = math.ceil(M / self.block_len) + inner = round_up_to_nearest_multiple(K, 4) + return (outer, inner) + outer = math.ceil(K / self.block_len) + inner = round_up_to_nearest_multiple(M, 4) + return (outer, inner) def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of a tensor after columnwise permutation. + + This method rearranges the dimensions of a tensor to be columnwise, + moving the last dimension to the front and keeping the order of other dimensions. + + Parameters + ---------- + shape : Iterable[int] + Original shape of the tensor + + Returns + ------- + Tuple[int, ...] + New shape with dimensions rearranged for columnwise layout. + For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1). + Returns empty tuple for empty input shape. + """ if len(shape) == 0: return tuple() colwise_shape = [shape[-1]] @@ -369,6 +422,7 @@ def _make_in_reduce_ex( columnwise_scale_inv: torch.Tensor, fp8_dtype: TE_DType, dtype: torch.dtype, + quantizer: Quantizer, ) -> Float8BlockwiseQTensor: """Build Float8BlockwiseQTensor, for use in __reduce__ @@ -383,6 +437,7 @@ def _make_in_reduce_ex( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, + quantizer=quantizer, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -396,6 +451,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._fp8_dtype, self.dtype, + self._quantizer, ), ) @@ -477,8 +533,7 @@ def forward( if shape != ctx.shape: raise NotImplementedError("View not implemented.") - else: - return tensor + return tensor @staticmethod def backward( @@ -526,6 +581,7 @@ def forward( break if shape != ctx.shape: raise NotImplementedError("Reshape not implemented yet.") + return tensor @staticmethod def backward( From 9ce20346df682e638fab24e0538f393b25fcf38d Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 27 Feb 2025 14:01:33 -0800 Subject: [PATCH 03/38] Alignment for 1D scaling for GEMM edge case. Signed-off-by: Keith Wyss --- .../cpp/operator/test_cast_float8blockwise.cu | 11 +++-- tests/cpp/test_common.cu | 4 +- .../blockwise_quantizer_reference.py | 46 ++++++++--------- .../test_float8_blockwise_scaling_exact.py | 49 +++++++++---------- .../pytorch/csrc/extensions/quantizer.cpp | 14 +++--- 5 files changed, 60 insertions(+), 64 deletions(-) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index 171d22be71..00a38af441 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -220,6 +220,10 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method } } +inline size_t scale_align_stride(size_t inner_elements) { + return ((inner_elements + 4u - 1u) / 4u) * 4u; +}; + void compare_scaling_factors(const std::string& name, const float* test, const float* ref, const size_t row_blocks, const size_t col_blocks, const size_t test_stride, const size_t ref_stride) { @@ -238,9 +242,10 @@ void compare_scaling_factors(const std::string& name, const float* test, const f void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, const float* ref, const size_t rows, const size_t col_blocks) { + const size_t test_stride = scale_align_stride(rows); for (int i = 0; i < rows; ++i) { for (int j = 0; j < col_blocks; ++j) { - const int test_idx = i + rows * j; + const int test_idx = i + test_stride * j; const int ref_idx = i + rows * j; ASSERT_FALSE(test[test_idx] != ref[ref_idx]) << "Error in " << name << std::endl @@ -306,10 +311,6 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector size_t { - return ((inner_elements + 4u - 1u) / 4u) * 4u; - }; - if (rowwise) { compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); compare_scaling_factors("scale_inv", output_c.rowwise_cpu_scale_inv_ptr(), diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index d3faac5c28..8224abd6c1 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -195,13 +195,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_1 = first_dim; auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(first_dim, 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_1 = last_dim; auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(last_dim, 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index 72cb062c31..d3460caea1 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -82,33 +82,26 @@ def munge_scale_shapes_for_backend( ) -> QuantizeResult: """ cuBLAS GEMMs requires 1x128 quantized tensors to be have scales transposed - so that for an (M, N) tensor, the scales are (RounUpDiv(N, 128), M) + so that for an (M, N) tensor, the scales are (RoundUpDiv(N, 128), RoundUp(M, 4)) For 128x128 quantized tensors, the GEMM expects (M, PadToAlign(RoundUpDivide(N, 128), 4)) format. If RoundUpDivide(N, 128) is not divisible by 4, a transformation is required """ - if tile_shape[0] != 1: - # 2D block quantized tensor needs padding for cuBLAS GEMM. - def _munge_scale_tensor(s: torch.Tensor) -> torch.Tensor: - M, K = s.shape - if K % 4 == 0: - return s - k_pad = 4 - (K % 4) - return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous() - - s = _munge_scale_tensor(unmunged.scale) - if unmunged.scale_t is None: - s_t = None - else: - s_t = _munge_scale_tensor(unmunged.scale_t) - return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) - # 1D block quantized tensors needs transpose to prepare for the GEMM. - s = unmunged.scale.transpose(-1, -2).contiguous() + def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: + if transpose: + s = s.transpose(-1, -2).contiguous() + M, K = s.shape + if K % 4 == 0: + return s + k_pad = 4 - (K % 4) + return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous() + + s = _pad_inner_to_align(unmunged.scale, transpose=tile_shape[0] == 1) if unmunged.scale_t is None: s_t = None else: - s_t = unmunged.scale_t.transpose(-1, -2).contiguous() + s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1) return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) def demunge_scale_shape_from_backend( @@ -123,12 +116,15 @@ def demunge_scale_shape_from_backend( if tile_shape[0] != 1: # 2D block quantized tensor may need padding stripped off derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1]) - M, K = scales.shape - if derived_scale_k_shape == K: - return scales - else: - return scales[:, :derived_scale_k_shape].contiguous() - return scales.transpose(-1, -2).contiguous() + else: + derived_scale_k_shape = qtensor_shape[0] + M, K = scales.shape + if derived_scale_k_shape != K: + scales = scales[:, :derived_scale_k_shape].contiguous() + if tile_shape[0] == 1: + return scales.transpose(-1, -2).contiguous() + else: + return scales @dataclasses.dataclass() diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 16647184d6..e113f6ea8b 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -72,7 +72,7 @@ def initialize_for_many_scales( (1024, 256), # Padding required cases (256, 272), - (300, 300), + (303, 300), (305, 256), # Some larger tiles. (2000, 2000), @@ -155,17 +155,15 @@ def test_quantization_block_tiling_versus_reference( # Check torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) - if tile_size[0] != 1: - # Zero out values that are don't care values - # cuBLAS has padding of 2D tensors. - scale_mask = torch.ones( - (math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device - ) - scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( - QuantizeResult(qx, scale_mask, None, None), tile_size - ).scale - sx = sx * scale_mask - + # Zero out values that are don't care values + # Scale format has padding. + scale_mask = torch.ones( + (math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx, scale_mask, None, None), tile_size + ).scale + sx = sx * scale_mask torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) if return_transpose: @@ -174,15 +172,14 @@ def test_quantization_block_tiling_versus_reference( assert qx_t_ref is not None assert sx_t is not None assert sx_t_ref is not None - if tile_size[0] != 1: - scale_mask = torch.ones( - (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])), - device=sx_t.device, - ) - scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( - QuantizeResult(qx_t, scale_mask, None, None), tile_size - ).scale - sx_t = sx_t * scale_mask + scale_mask = torch.ones( + (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])), + device=sx_t.device, + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx_t, scale_mask, None, None), tile_size + ).scale + sx_t = sx_t * scale_mask torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0) torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0) else: @@ -195,14 +192,14 @@ def test_quantization_block_tiling_versus_reference( "M, N", [ # full tile cases - (1, 128), + (128, 128), ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0, math.pow(2, -125)], ids=["eps_0", "eps_small"]) @pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) -@pytest.mark.parametrize("tile_size", [(1, 128)]) +@pytest.mark.parametrize("tile_size", [(128, 128)]) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) def test_quantization_block_tiling_extrema_versus_reference( x_dtype: torch.dtype, @@ -266,7 +263,7 @@ def test_quantization_block_tiling_extrema_versus_reference( # Check torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) - torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(sx.flatten()[0], sx_ref.flatten()[0], atol=0.0, rtol=0.0) if extrema_high: expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max @@ -284,8 +281,8 @@ def test_quantization_block_tiling_extrema_versus_reference( else: expected_value = 1 / torch.finfo(x_dtype).max torch.testing.assert_close( - sx, - torch.tensor([expected_value], device=sx.device).reshape(1, 1), + sx.flatten()[0], + torch.tensor(expected_value, device=sx.device), atol=0.0, rtol=0.0, ) diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 9ebb9d4f86..e4e7d896d7 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -316,7 +316,7 @@ std::pair Float8BlockQuantizer::create_tensor( sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); } else if (block_scaling_dim == 1) { sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; - sinv1 = m_dim; + sinv1 = roundup(m_dim, 4); } else { NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor rowwise."); } @@ -347,7 +347,7 @@ std::pair Float8BlockQuantizer::create_tensor( sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); } else if (block_scaling_dim == 1) { sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; - sinv1 = k_dim; + sinv1 = roundup(k_dim, 4); } else { NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor columnwise."); } @@ -430,8 +430,9 @@ std::pair MXFP8Quantizer::create_tensor( auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{sinv0, sinv1}); + tensor.set_rowwise_scale_inv( + rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{static_cast(sinv0), static_cast(sinv1)}); } if (columnwise_usage) { @@ -441,8 +442,9 @@ std::pair MXFP8Quantizer::create_tensor( columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); - tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{sinv0, sinv1}); + tensor.set_columnwise_scale_inv( + columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{static_cast(sinv0), static_cast(sinv1)}); } this->set_quantization_params(&tensor); From 86d4be85e826b163ee6f391e04cc2cfd2a98b82d Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 14:16:23 -0700 Subject: [PATCH 04/38] MR feedback. Signed-off-by: Keith Wyss --- transformer_engine/common/transpose/compute_scale.cuh | 6 +++--- transformer_engine/common/util/cast_kernels.cuh | 4 +++- transformer_engine/common/util/dequantize_kernels.cuh | 2 +- .../tensor/_internal/float8_blockwise_tensor_base.py | 7 ++++--- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/transpose/compute_scale.cuh b/transformer_engine/common/transpose/compute_scale.cuh index 82ce6c5df7..0f17829fb2 100644 --- a/transformer_engine/common/transpose/compute_scale.cuh +++ b/transformer_engine/common/transpose/compute_scale.cuh @@ -114,9 +114,9 @@ __device__ __forceinline__ float ComputeScale(const float amax, const float eps) return scale; } if constexpr (Power2Scaling) { - // NOTE: using bit fiddling based on advice of Asit in this - // thread: https://nvidia.slack.com/archives/C06EDT7LZEW/p1738274404153439 - + // NOTE: using bit fiddling rather than pow2, exp to + // be exact. + // // inf scales already early returned, as did nan scales. // The cases to consider here are normals, zero, and subnormals. // zero is not possible with current math as diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index c342805b1a..1cc4ae9873 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1263,7 +1263,9 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe break; } case NVTE_BLOCK_SCALING: { - // FIXME(kwyss): Currently ignoring IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters. + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING"); if (output_tensor->block_scaling_dim == 2) { nvte_quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 6e798ca748..c885c69333 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -349,7 +349,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } } else { - // FIXME(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING + // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 42c236dbf5..8d2b2dad4c 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -69,10 +69,11 @@ def prepare_for_saving( ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: """Prepare the tensor base for saving for backward - FIXME(kwyss): Should this clear out data? - FIXME(kwyss): What about dq scales? + FIXME(kwyss): Set data tensors to None and consider saving/restoring scales. + test_numerics.py fails when tensors are cleared at the moment in C++ shape logic. """ - tensors = [self._rowwise_data, self._columnwise_data] + tensors = [self._rowwise_data, + self._columnwise_data] return tensors, self def restore_from_saved( From 03f88f46c06cc5de563cedf479e4ac24d728e44c Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 15:20:44 -0700 Subject: [PATCH 05/38] Change API name. Signed-off-by: Keith Wyss --- .../common/transpose/cast_transpose.h | 22 ++++++------ .../quantize_transpose_square_blockwise.cu | 34 +++++++++++-------- .../quantize_transpose_vector_blockwise.cu | 12 +++---- .../common/util/cast_kernels.cuh | 4 +-- 4 files changed, 38 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index cf2cb15174..298d087337 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -23,17 +23,17 @@ template -CUtensorMap get_tensor_map(SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { - // example-begin create-tensor-map +CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { + CUtensorMapDataType dataType; + if constexpr (std::is_same_v || + std::is_same_v) { + dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else { + NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + } + CUtensorMap tensor_map_output_trans{}; + // create_2D_tensor_map(tensor_map_output_trans, tensor, + // global_dim_y, global_dim_x, /*shmemY=*/ BLOCK_TILE_DIM, /*shmemX=*/ BLOCK_TILE_DIM, /*stride_elems=*/ global_dim_x, /*offset_elems=*/ 0, sizeof(OutputType)); + // return tensor_map_output_trans; // rank is the number of dimensions of the array. constexpr uint32_t rank = 2; uint64_t size[rank] = {global_dim_x, global_dim_y}; // x, y @@ -473,14 +483,8 @@ CUtensorMap get_tensor_map(SimpleTensor& tensor, size_t global_dim_x, size_t glo // Get a function pointer to the cuTensorMapEncodeTiled driver API. auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); - CUtensorMapDataType dataType; - if constexpr (std::is_same_v || - std::is_same_v) { - dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else { - NVTE_CHECK(false, "Invalid Output type (must be FP8)."); - } + // Create the tensor descriptor. CUresult res = cuTensorMapEncodeTiled( @@ -507,12 +511,12 @@ CUtensorMap get_tensor_map(SimpleTensor& tensor, size_t global_dim_x, size_t glo namespace transformer_engine::detail { -void nvte_quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, - SimpleTensor& scale_inv_t, SimpleTensor& output, - SimpleTensor& output_t, const float epsilon, - const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_transpose_square_blockwise); +void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow_2_scale, + cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_square_blockwise); NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index bacb757822..de53071eda 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -400,12 +400,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) namespace transformer_engine::detail { -void nvte_quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, - SimpleTensor& scale_inv_t, SimpleTensor& output, - SimpleTensor& output_t, const float epsilon, - const bool return_transpose, const bool pow2_scale, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_transpose_vector_blockwise); +void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow2_scale, + cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_vector_blockwise); NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 1cc4ae9873..ac75baa86c 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1267,14 +1267,14 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING"); if (output_tensor->block_scaling_dim == 2) { - nvte_quantize_transpose_square_blockwise( + quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, /*epsilon=*/output_tensor->amax_epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), output_tensor->force_pow_2_scales, stream); } else if (output_tensor->block_scaling_dim == 1) { - nvte_quantize_transpose_vector_blockwise( + quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, /*epsilon=*/output_tensor->amax_epsilon, From 60e86c098d778df7be563ca89ba3c0dadc6fdeb4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 15:23:53 -0700 Subject: [PATCH 06/38] Fix merge conflict with name change. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/csrc/pybind.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 6cd62cd1d0..c7b3167e78 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -53,7 +53,7 @@ inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; } -inline bool IsFloat8BlockwiseQParams(PyObject *obj) { +inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; } @@ -85,9 +85,9 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), - std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, + std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, CreateQuantizer), - std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQParams, + std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; } // namespace detail From 00dffe233cb4b060e6aced84ecc73caea235a2cd Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 15:32:29 -0700 Subject: [PATCH 07/38] Use common tensor map API. Signed-off-by: Keith Wyss --- .../quantize_transpose_square_blockwise.cu | 49 ++----------------- .../common/util/cast_kernels.cuh | 2 +- transformer_engine/pytorch/csrc/common.h | 4 +- .../pytorch/csrc/extensions/quantizer.cpp | 1 - .../_internal/float8_blockwise_tensor_base.py | 3 +- 5 files changed, 7 insertions(+), 52 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 8d142385b7..ca5633556d 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -446,14 +446,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } -PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { - void* driver_ptr = nullptr; - cudaDriverEntryPointQueryResult driver_status; - NVTE_CHECK_CUDA(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, - &driver_status)); - return reinterpret_cast(driver_ptr); -} - template CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { CUtensorMapDataType dataType; @@ -465,44 +457,9 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size } CUtensorMap tensor_map_output_trans{}; - // create_2D_tensor_map(tensor_map_output_trans, tensor, - // global_dim_y, global_dim_x, /*shmemY=*/ BLOCK_TILE_DIM, /*shmemX=*/ BLOCK_TILE_DIM, /*stride_elems=*/ global_dim_x, /*offset_elems=*/ 0, sizeof(OutputType)); - // return tensor_map_output_trans; - // rank is the number of dimensions of the array. - constexpr uint32_t rank = 2; - uint64_t size[rank] = {global_dim_x, global_dim_y}; // x, y - // The stride is the number of bytes to traverse from the first element of one row to the next. - // It must be a multiple of 16. - uint64_t stride[rank - 1] = {global_dim_x * sizeof(OutputType)}; - // The box_size is the size of the shared memory buffer that is used as the - // destination of a TMA transfer. - uint32_t box_size[rank] = {BLOCK_TILE_DIM, BLOCK_TILE_DIM}; - // The distance between elements in units of sizeof(element). A stride of 2 - // can be used to load only the real component of a complex-valued tensor, for instance. - uint32_t elem_stride[rank] = {1, 1}; - - // Get a function pointer to the cuTensorMapEncodeTiled driver API. - auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); - - - - // Create the tensor descriptor. - CUresult res = cuTensorMapEncodeTiled( - &tensor_map_output_trans, // CUtensorMap *tensorMap, - dataType, - rank, // cuuint32_t tensorRank, - reinterpret_cast(tensor.dptr), // void *globalAddress, - size, // const cuuint64_t *globalDim, - stride, // const cuuint64_t *globalStrides, - box_size, // const cuuint32_t *boxDim, - elem_stride, // const cuuint32_t *elementStrides, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, - // Swizzling can be used to avoid shared memory bank conflicts. - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, - // Any element that is outside of bounds will be set to zero by the TMA transfer. - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - + create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x, + /*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM, + /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType)); return tensor_map_output_trans; } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index ac75baa86c..30dd03a804 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1265,7 +1265,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe case NVTE_BLOCK_SCALING: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING"); + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING"); if (output_tensor->block_scaling_dim == 2) { quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 1fae53b791..a03f9b2175 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -159,8 +159,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; }; class Float8BlockQuantizer : public Quantizer { diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index e4e7d896d7..74951d2714 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -281,7 +281,6 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const tensor->set_qopt_block_scaling_dim(block_scaling_dim); } - std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype, std::optional rowwise_data) const { using namespace pybind11::literals; diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 8d2b2dad4c..f681a0ad70 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -72,8 +72,7 @@ def prepare_for_saving( FIXME(kwyss): Set data tensors to None and consider saving/restoring scales. test_numerics.py fails when tensors are cleared at the moment in C++ shape logic. """ - tensors = [self._rowwise_data, - self._columnwise_data] + tensors = [self._rowwise_data, self._columnwise_data] return tensors, self def restore_from_saved( From f6b53920ebd83a14822cf1f845931657eae2aa26 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 16:06:22 -0700 Subject: [PATCH 08/38] Change API to use two scaling mode enums. Signed-off-by: Keith Wyss --- .../cpp/operator/test_cast_float8blockwise.cu | 6 +- tests/cpp/test_common.cu | 111 +++++++++--------- tests/cpp/test_common.h | 4 +- tests/pytorch/test_float8blockwisetensor.py | 7 +- transformer_engine/common/common.h | 19 +-- .../common/include/transformer_engine/cast.h | 2 +- .../transformer_engine/transformer_engine.h | 24 +--- .../common/transformer_engine.cpp | 15 --- .../common/util/cast_kernels.cuh | 39 +++--- transformer_engine/pytorch/csrc/common.h | 4 +- .../pytorch/csrc/extensions/quantizer.cpp | 8 +- .../csrc/extensions/type_converters.cpp | 4 +- .../_internal/float8_blockwise_tensor_base.py | 8 +- .../pytorch/tensor/float8_blockwise_tensor.py | 1 + 14 files changed, 110 insertions(+), 142 deletions(-) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index 00a38af441..ce3175aafd 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -272,7 +272,8 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector ref_output = std::make_unique(rows * cols); @@ -343,7 +344,8 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, Tensor input("input", shape, itype); Tensor grad("grad", shape, itype); - Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_BLOCK_SCALING, &opts); + Tensor output_c("output_c", shape, otype, rowwise, colwise, + opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); Tensor output_dbias("output_dbias", {cols}, itype); std::unique_ptr ref_output = std::make_unique(rows * cols); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 8224abd6c1..4ecd51dca3 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -117,8 +117,7 @@ NVTEShape convertShape(const std::vector& shape) { } std::pair get_scales(const NVTEShape& shape, - const NVTEScalingMode scaling_mode, - const int block_scaling_dim) { + const NVTEScalingMode scaling_mode) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { scale_inv_meta ret; ret.shape = {1}; @@ -158,60 +157,57 @@ std::pair get_scales(const NVTEShape& shape, return {ret_rowwise, ret_colwise}; } - if (scaling_mode == NVTE_BLOCK_SCALING) { - if (block_scaling_dim == 2) { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - size_t first_dim = first_dimension(shape_vec); - size_t last_dim = last_dimension(shape_vec); + if (scaling_mode == NVTE_BLOCK_SCALING_2D) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); - scale_inv_meta ret_rowwise, ret_colwise; + scale_inv_meta ret_rowwise, ret_colwise; - { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; - } - { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; - ret_colwise.shape = {scale_dim_0, scale_dim_1}; - } - ret_rowwise.type = DType::kFloat32; - ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size = sizeof(float); - ret_colwise.type_size = sizeof(float); - - return {ret_rowwise, ret_colwise}; - } else if (block_scaling_dim == 1) { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - size_t first_dim = first_dimension(shape_vec); - size_t last_dim = last_dimension(shape_vec); - scale_inv_meta ret_rowwise, ret_colwise; + { + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); - { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(first_dim, 4) * 4; - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; - } - { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(last_dim, 4) * 4; - ret_colwise.shape = {scale_dim_0, scale_dim_1}; - } - ret_rowwise.type = DType::kFloat32; - ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size = sizeof(float); - ret_colwise.type_size = sizeof(float); - return {ret_rowwise, ret_colwise}; - } else { - NVTE_ERROR("Unsupported block scaling dim!"); + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(first_dim, 4) * 4; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } + { + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(last_dim, 4) * 4; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + return {ret_rowwise, ret_colwise}; } NVTE_ERROR("Invalid scaling mode!"); @@ -252,7 +248,7 @@ Tensor::Tensor(const std::string& name, block_scaling_dim = q_opts->block_scaling_dim; } std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { @@ -314,7 +310,7 @@ Tensor::Tensor(const std::string& name, } } else { auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(normalized_shape, tensor_.scaling_mode(), block_scaling_dim); + get_scales(normalized_shape, tensor_.scaling_mode()); auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_shape = rowwise_scale_meta.shape; @@ -339,7 +335,6 @@ Tensor::Tensor(const std::string& name, if (q_opts != nullptr) { tensor_.set_qopt_force_pow_2_scales(q_opts->force_pow_2_scales); tensor_.set_qopt_amax_epsilon(q_opts->amax_epsilon); - tensor_.set_qopt_block_scaling_dim(q_opts->block_scaling_dim); } } } @@ -373,7 +368,7 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); + get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -412,7 +407,7 @@ void Tensor::from_cpu() const { cudaMemcpyHostToDevice); } auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); + get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -448,7 +443,7 @@ void Tensor::set_scale_inv(float scale_inv) { } auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(tensor_.shape(), tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); + get_scales(tensor_.shape(), tensor_.scaling_mode()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); if (num_scales == 1) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 8f26ac7419..08df3cf7d1 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -223,7 +223,7 @@ class Tensor { T *rowwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); - } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); @@ -236,7 +236,7 @@ class Tensor { T *columnwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); - } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 7058fdb22f..15ccce6961 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -61,6 +61,7 @@ def test_constructor( dims: DimsType = 1, fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, dtype: torch.dtype = torch.float32, + is_2D_scaled: bool = True, ) -> None: """Call constructor and perform sanity checks""" dims = _to_list(dims) @@ -68,7 +69,10 @@ def test_constructor( rowwise = True columnwise = True quantizer = Float8BlockQuantizer( - fp8_dtype=fp8_dtype, rowwise=rowwise, columnwise=columnwise + fp8_dtype=fp8_dtype, + rowwise=rowwise, + columnwise=columnwise, + block_scaling_dim=2 if is_2D_scaled else 1, ) scale_dims = quantizer.get_scale_shape(dims, columnwise=False) @@ -84,6 +88,7 @@ def test_constructor( columnwise_scale_dims, device="cuda", dtype=torch.float32 ), fp8_dtype=fp8_dtype, + is_2D_scaled=is_2D_scaled, quantizer=quantizer, ) assert list(tensor.size()) == dims, "Incorrect dims" diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 36106e0110..47a84ad068 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -103,7 +103,6 @@ struct Tensor { float amax_epsilon; bool force_pow_2_scales; - int block_scaling_dim; Tensor() : data(), @@ -114,8 +113,7 @@ struct Tensor { columnwise_scale_inv(nullptr, {1}, DType::kFloat32), scaling_mode(NVTE_DELAYED_TENSOR_SCALING), amax_epsilon(0.0), - force_pow_2_scales(false), - block_scaling_dim(scaling_mode == NVTE_BLOCK_SCALING ? 2 : 0) {} + force_pow_2_scales(false) {} int numel() const { size_t acc = 1; @@ -134,7 +132,8 @@ struct Tensor { bool supports_force_pow_2_scales_qopt() const noexcept { switch (scaling_mode) { - case NVTE_BLOCK_SCALING: + case NVTE_BLOCK_SCALING_2D: + case NVTE_BLOCK_SCALING_1D: return true; default: return false; @@ -143,22 +142,14 @@ struct Tensor { bool supports_amax_epsilon_qopt() const noexcept { switch (scaling_mode) { - case NVTE_BLOCK_SCALING: + case NVTE_BLOCK_SCALING_2D: + case NVTE_BLOCK_SCALING_1D: return true; default: return false; } } - bool supports_block_scaling_dim(int block_scaling_dim) const noexcept { - switch (scaling_mode) { - case NVTE_BLOCK_SCALING: - return block_scaling_dim == 1 || block_scaling_dim == 2; - default: - return false; - } - } - DType dtype() const { if (has_data()) return data.dtype; if (has_columnwise_data()) return columnwise_data.dtype; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 9be0e14d8a..3c0d24df78 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -45,7 +45,7 @@ extern "C" { /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the MXFP8 block quantization of the specified shape of the block will be used. - * If the scaling mode of the output tensor is set to NVTE_BLOCK_SCALING, + * If the scaling mode of the output tensor is set to NVTE_BLOCK_SCALING_1D or NVTE_BLOCK_SCALING_2D, * blockwise float8 scaling will be used. * * \param[in] input Input tensor to be cast. diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index bae30e3e05..0e2b6cfcff 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -84,9 +84,10 @@ enum NVTEScalingMode { which each yield a scale. The block_scaling_dim property of the quantizer selects the granularity. */ - NVTE_BLOCK_SCALING = 2, - NVTE_INVALID_SCALING = 3, - NVTE_NO_SCALING = 4 + NVTE_BLOCK_SCALING_1D = 2, + NVTE_BLOCK_SCALING_2D = 3, + NVTE_INVALID_SCALING = 4, + NVTE_NO_SCALING = 5 }; /*! \brief TE Tensor type @@ -260,17 +261,6 @@ int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false); */ int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon); -/*! \brief Set a quantization option to use 1D or 2D quantization blocks - * to scale the tensor. - * - * \param[in/out] tensor Tensor. - * \param[in] block_scaling_dim, 1D or 2D. - * - * \return zero if the tensor supports this option and it was set. non-zero if - * call had no effect or the number of dims is not supported. - */ -int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim); - /*! \brief Get a quantization option for whether to force power of 2 scales. * * \param[in] tensor Tensor. @@ -728,18 +718,12 @@ class TensorWrapper { int set_qopt_amax_epsilon(float eps) { return nvte_set_qopt_amax_epsilon(tensor_, eps); } - int set_qopt_block_scaling_dim(int block_scaling_dim) { - return nvte_set_qopt_block_scaling_dim(tensor_, block_scaling_dim); - } - bool get_qopt_force_pow_2_scales() const { return nvte_get_qopt_force_pow_2_scales(tensor_) != 0; } float get_qopt_amax_epsilon() const { return nvte_get_qopt_amax_epsilon(tensor_); } - int get_qopt_block_scaling_dim() const { return nvte_get_qopt_block_scaling_dim(tensor_); } - static constexpr size_t defaultData = 1; static constexpr NVTEShape defaultShape = {&defaultData, 1}; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 63d28c41cf..2fa9fc3aba 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -523,16 +523,6 @@ int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon) { } } -int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim) { - auto &t = *reinterpret_cast(tensor); - if (t.supports_block_scaling_dim(block_scaling_dim)) { - t.block_scaling_dim = block_scaling_dim; - return 0; - } else { - return 1; - } -} - int nvte_get_qopt_force_pow_2_scales(const NVTETensor tensor) { const auto &t = *reinterpret_cast(tensor); return t.force_pow_2_scales ? 1 : 0; @@ -542,8 +532,3 @@ float nvte_get_qopt_amax_epsilon(const NVTETensor tensor) { const auto &t = *reinterpret_cast(tensor); return t.amax_epsilon; } - -int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); - return t.block_scaling_dim; -} diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 30dd03a804..458f2ff217 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1262,27 +1262,28 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe workspace_tensor, stream); break; } - case NVTE_BLOCK_SCALING: { + case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING"); - if (output_tensor->block_scaling_dim == 2) { - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/output_tensor->amax_epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), - output_tensor->force_pow_2_scales, stream); - } else if (output_tensor->block_scaling_dim == 1) { - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, - /*epsilon=*/output_tensor->amax_epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), - output_tensor->force_pow_2_scales, stream); - } else { - NVTE_ERROR("Not supported block scaling dim."); - } + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); + quantize_transpose_square_blockwise(input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); + quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, output_tensor->data, + output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index a03f9b2175..338f1fcbb1 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -180,7 +180,9 @@ class Float8BlockQuantizer : public Quantizer { // Initializes from a python handle to a Float8BlockQuantizer explicit Float8BlockQuantizer(const py::handle& quantizer); - NVTEScalingMode get_scaling_mode() const override { return NVTE_BLOCK_SCALING; } + NVTEScalingMode get_scaling_mode() const override { + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } // Gets rowwise and columnwise_data from tensor and sets them on wrapper void set_quantization_params(TensorWrapper* tensor) const override; diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 74951d2714..5826427574 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -278,7 +278,6 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const // Set options on TensorWrapper from quantization. tensor->set_qopt_force_pow_2_scales(force_pow_2_scales); tensor->set_qopt_amax_epsilon(amax_epsilon); - tensor->set_qopt_block_scaling_dim(block_scaling_dim); } std::pair Float8BlockQuantizer::create_tensor( @@ -291,7 +290,7 @@ std::pair Float8BlockQuantizer::create_tensor( numel *= s; } - TensorWrapper tensor(NVTE_BLOCK_SCALING); + TensorWrapper tensor((block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); at::TensorOptions opts; at::TensorOptions scale_opts; at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; @@ -366,7 +365,8 @@ std::pair Float8BlockQuantizer::create_tensor( ret = Float8BlockwiseQTensorClass( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, + "is_2D_scaled"_a = (block_scaling_dim == 2)); } else { py::handle Float8BlockwiseQTensorClass( reinterpret_cast(Float8BlockwiseQTensorPythonClass)); @@ -374,7 +374,7 @@ std::pair Float8BlockQuantizer::create_tensor( "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); } return {std::move(tensor), std::move(ret)}; diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index 18a08605b6..440e819ae7 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -86,7 +86,8 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { const DType dtype = tensor.attr("_fp8_dtype").cast(); - auto ret = TensorWrapper(NVTE_BLOCK_SCALING); + bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); + auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); @@ -101,7 +102,6 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); } - if (columnwise_usage) { const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index f681a0ad70..759d49b495 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -31,6 +31,7 @@ class Float8BlockwiseQTensorBase: _fp8_dtype: TE_DType _rowwise_scale_inv: Optional[torch.Tensor] _columnwise_scale_inv: Optional[torch.Tensor] + _is_2D_scaled: bool def __new__( cls, @@ -41,6 +42,7 @@ def __new__( columnwise_scale_inv: Optional[torch.Tensor], fp8_dtype: TE_DType, quantizer: Quantizer, + is_2D_scaled: bool, **kwargs, ): instance = super().__new__(cls, *args, **kwargs) @@ -50,6 +52,7 @@ def __new__( instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv + instance._is_2D_scaled = is_2D_scaled return instance @@ -62,6 +65,7 @@ def get_metadata(self) -> Dict[str, Any]: "columnwise_scale_inv": self._columnwise_scale_inv, "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, + "is_2D_scaled": self._is_2D_scaled, } def prepare_for_saving( @@ -150,9 +154,7 @@ def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: Construct plain PyTorch tensor from Float8BlockwiseQTensor """ block_len = 128 - assert self._quantizer is not None - if self._quantizer.block_scaling_dim != 2: - assert self._quantizer.block_scaling_dim == 1 + if not self._is_2D_scaled: return self._dequantize_vectorwise(dtype=dtype) def format_scale_as_logical_shape(q_K, scales, block_len): diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 4b4dc3b94e..89a5cc4a58 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -213,6 +213,7 @@ def make_empty( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, + is_2D_scaled==self.block_scaling_dim==2, requires_grad=requires_grad, ) From 33f2ed047f31fdb5d242820b7110bf731e485120 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 17:52:10 -0700 Subject: [PATCH 09/38] Fix typo. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 89a5cc4a58..1506880739 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -213,7 +213,7 @@ def make_empty( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, - is_2D_scaled==self.block_scaling_dim==2, + is_2D_scaled=self.block_scaling_dim == 2, requires_grad=requires_grad, ) From 125342da1744e1e1ed17c8427e0860dca2ec0027 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 18:04:15 -0700 Subject: [PATCH 10/38] Update some call sites. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 1506880739..943246f4d7 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -407,6 +407,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): columnwise_data=tensor._columnwise_data, columnwise_scale_inv=tensor._columnwise_scale_inv, quantizer=tensor._quantizer, + is_2D_scaled=tensor._is_2D_scaled, requires_grad=False, fp8_dtype=tensor._fp8_dtype, ) @@ -424,6 +425,7 @@ def _make_in_reduce_ex( fp8_dtype: TE_DType, dtype: torch.dtype, quantizer: Quantizer, + is_2D_scaled: bool, ) -> Float8BlockwiseQTensor: """Build Float8BlockwiseQTensor, for use in __reduce__ @@ -439,6 +441,7 @@ def _make_in_reduce_ex( columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, quantizer=quantizer, + is_2D_scaled=is_2D_scaled, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -453,6 +456,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._fp8_dtype, self.dtype, self._quantizer, + self._is_2D_scaled, ), ) From 035e1c973bb1d77866461881c9793856b71ffa55 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 12 Mar 2025 12:22:55 -0700 Subject: [PATCH 11/38] Tests for torch tensor API surface. Since the quantized tensor is a tensor subclass, these tests exercise torch hooks. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8blockwisetensor.py | 249 +++++++++++++++++- .../pytorch/tensor/float8_blockwise_tensor.py | 127 ++++----- 2 files changed, 300 insertions(+), 76 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 15ccce6961..316842b4f3 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -26,8 +26,8 @@ # Numerical tolerances with FP8 types _tols: Dict[tex.DType, Dict[str, float]] = { - tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 - tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 + tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.08), + tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), } @@ -199,8 +199,243 @@ def test_quantize_dequantize_dims_cpp_allocate_output( use_cpp_allocation=True, ) - # FIXME(kwyss): Add some testing for other tensor operations. - # - basic_ops - # - in_place_ops - # - serialization - # - set_data + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None: + """Test data accessors of Float8BlockwiseQTensor""" + device = "cuda" + dtype = torch.bfloat16 + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + fp8_dtype = tex.DType.kFloat8E4M3 + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + # Create FP8 tensor + x_fp8 = quantizer.quantize(x_hp) + + x_recovered = x_fp8.data + torch.testing.assert_close(x_recovered, x_hp, **_tols[fp8_dtype]) + + x_fp8.data = y_hp + y_recovered = x_fp8.data + torch.testing.assert_close(y_recovered, y_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: + """Test serialization of Float8BlockwiseQTensor""" + device = "cuda" + dtype = torch.bfloat16 + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + # Create FP8 tensor + x_fp8 = quantizer.quantize(x_hp) + + # Save tensor + buffer = io.BytesIO() + torch.save(x_fp8, buffer) + + # Load tensor + buffer.seek(0) + x_fp8_loaded = torch.load(buffer, weights_only=False) + + # Test that loaded tensor matches original + assert isinstance(x_fp8_loaded, Float8BlockwiseQTensor) + torch.testing.assert_close(x_fp8_loaded._rowwise_data, x_fp8._rowwise_data) + torch.testing.assert_close(x_fp8_loaded._columnwise_data, x_fp8._columnwise_data) + torch.testing.assert_close(x_fp8_loaded._rowwise_scale_inv, x_fp8._rowwise_scale_inv) + torch.testing.assert_close(x_fp8_loaded._columnwise_scale_inv, x_fp8._columnwise_scale_inv) + torch.testing.assert_close(x_fp8_loaded.data, x_fp8.data) + assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled + assert x_fp8_loaded.dtype == x_fp8.dtype + assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype + + # Test that dequantized values match + x_fp8_dequant = x_fp8.dequantize() + x_fp8_loaded_dequant = x_fp8_loaded.dequantize() + torch.testing.assert_close(x_fp8_loaded_dequant, x_fp8_dequant) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_inplace_ops( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test in-place operations""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + # Test in-place add + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + x_fp8.add_(y_fp8) + torch.testing.assert_close(x_fp8.dequantize(), x_hp + y_hp, **_tols[fp8_dtype]) + + # Test in-place subtract + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + x_fp8.sub_(y_fp8) + torch.testing.assert_close(x_fp8.dequantize(), x_hp - y_hp, **_tols[fp8_dtype]) + + # Test in-place multiply + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + x_fp8.mul_(y_fp8) + torch.testing.assert_close(x_fp8.dequantize(), x_hp * y_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_out_of_place_ops( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test out-of-place operations""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + + # Test exact operations + torch.testing.assert_close(-x_fp8, -x_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(x_fp8.abs(), x_hp.abs(), **_tols[fp8_dtype]) + + # Test elementwise operations + torch.testing.assert_close(x_fp8 + y_fp8, x_hp + y_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(x_fp8 - y_fp8, x_hp - y_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(x_fp8 * y_fp8, x_hp * y_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_hp), **_tols[fp8_dtype]) + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8 + y_fp8, x_hp - y_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_view_same_shape( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test view operations that preserve tensor shape""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device) + quantizer.update_quantized(x_hp.clone(), x_fp8) + + # Test view with same shape + x_view = x_fp8.view(*dims) + torch.testing.assert_close(x_view.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_view.shape == x_fp8.shape, "Shape changed after view with same dims" + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_reshape_same_shape( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test reshape operations that preserve tensor shape""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device) + quantizer.update_quantized(x_hp.clone(), x_fp8) + + # Test reshape with same shape + x_reshape = x_fp8.reshape(*dims) + torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with same dims" + + # Test reshape with -1 canonicalization + new_dims = [-1, dims[1]] + x_reshape = x_fp8.reshape(*new_dims) + torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with -1" + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_reshape.dequantize(), -x_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_clone_detach( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test clone and detach operations""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.quantize(x_hp.clone()) + + # Test clone + x_clone = x_fp8.clone() + torch.testing.assert_close(x_clone.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_clone.shape == x_fp8.shape, "Shape changed after clone" + + # Test detach + x_detach = x_fp8.detach() + torch.testing.assert_close(x_detach.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_detach.shape == x_fp8.shape, "Shape changed after detach" + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_clone.dequantize(), -x_hp, **_tols[fp8_dtype]) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 943246f4d7..56c8a934ed 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -354,11 +354,28 @@ def clone(self) -> Float8BlockwiseQTensor: def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring - return _ViewFunc.apply(self, shape) + if not self.requires_grad: + # Autograd removes the quantized return type + # because of __torch_function__ in base class + # and torch._C._disabled_torch_function_impl + return _ViewFunc.forward(None, self, shape) + return super.view(self, *shape) def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring - return _ReshapeFunc.apply(self, shape) + if not self.requires_grad: + return _ReshapeFunc.forward(None, self, shape) + return super.reshape(self, *shape) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + return _ViewFunc.apply(args[0], *args[1:]) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) def contiguous( self, @@ -385,39 +402,10 @@ def clear(self): self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - - # View op - if func == aten.view.default: - tensor = args[0] - data = tensor._rowwise_data - out_data = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - out_shape = out_data.size() - return Float8BlockwiseQTensor( - shape=out_shape, - dtype=tensor.dtype, - rowwise_data=out_data, - rowwise_scale_inv=tensor._rowwise_scale_inv, - columnwise_data=tensor._columnwise_data, - columnwise_scale_inv=tensor._columnwise_scale_inv, - quantizer=tensor._quantizer, - is_2D_scaled=tensor._is_2D_scaled, - requires_grad=False, - fp8_dtype=tensor._fp8_dtype, - ) - - # Default case - return super().__torch_dispatch__(func, types, args, kwargs) - @classmethod def _make_in_reduce_ex( cls, + shape: torch.Size, rowwise_data: torch.Tensor, rowwise_scale_inv: torch.Tensor, columnwise_data: torch.Tensor, @@ -434,6 +422,7 @@ def _make_in_reduce_ex( """ return Float8BlockwiseQTensor( + shape=shape, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, fp8_dtype=fp8_dtype, @@ -449,6 +438,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: return ( Float8BlockwiseQTensor._make_in_reduce_ex, ( + self.shape, self._rowwise_data, self._rowwise_scale_inv, self._columnwise_data, @@ -462,7 +452,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: def _get_data(self) -> Float8BlockwiseQTensor: """Get tensor data property""" - return super().data + return self.dequantize() @torch.no_grad() def _set_data(self, tensor: torch.Tensor) -> None: @@ -476,41 +466,37 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Tensor device new_device = tensor.device if tensor.is_cuda else self.device + def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): + dst._rowwise_data = src._rowwise_data + dst._columnwise_data = src._columnwise_data + dst._quantizer = src._quantizer + dst._fp8_dtype = src._fp8_dtype + dst._rowwise_scale_inv = src._rowwise_scale_inv + dst._columnwise_scale_inv = src._columnwise_scale_inv + if dst.requires_grad != src.requires_grad: + dst.requires_grad_(requires_grad=src.requires_grad) + # Just copy FP8 data if other tensor is Float8BlockwiseQTensor - if isinstance(tensor, Float8BlockwiseQTensor): - if ( # pylint: disable=too-many-boolean-expressions - self.size() != tensor.size() - or self.stride() != tensor.stride() - or self.storage_offset() != tensor.storage_offset() - or self.dtype != tensor.dtype - or self.layout != tensor.layout - or not devices_match(self.device, new_device) - ): - dummy_tensor = torch.Tensor._make_wrapper_subclass( - Float8BlockwiseQTensor, - tensor.size(), - strides=tensor.stride(), - storage_offset=tensor.storage_offset(), - dtype=tensor.dtype, - layout=tensor.layout, - requires_grad=tensor.requires_grad, - device=new_device, - ) - # pylint: disable=unnecessary-dunder-call - super(Float8BlockwiseQTensor, type(self)).data.__set__(self, dummy_tensor) - self._rowwise_data = tensor._rowwise_data - self._columnwise_data = tensor._columnwise_data - self._quantizer = tensor._quantizer - self._fp8_dtype = tensor._fp8_dtype - self._rowwise_scale_inv = tensor._rowwise_scale_inv - self._columnwise_scale_inv = tensor._columnwise_scale_inv + if ( + isinstance(tensor, Float8BlockwiseQTensor) + and self.size() == tensor.size() + and self.stride() == tensor.stride() + and self.storage_offset() == tensor.storage_offset() + and self.dtype == tensor.dtype + and self.layout == tensor.layout + and devices_match(self.device, new_device) + ): + _set_from_tensor(self, tensor) return + elif isinstance(tensor, Float8BlockwiseQTensor): + assert tensor._quantizer is not None, "Can't quantize without a quantizer" + quantizer = tensor._quantizer + else: + assert self._quantizer is not None, "Can't quantize without a quantizer" + quantizer = self._quantizer # Quantize to FP8 - assert self._quantizer is not None, "Can't quantize without a quantizer" - self.data = self._quantizer.quantize(tensor) - if self.requires_grad != tensor.requires_grad: - self.requires_grad_(requires_grad=tensor.requires_grad) + quantizer.update_quantized(tensor, self) # Cast to FP8 when setting Float8BlockwiseQTensor.data data = property(_get_data, _set_data) @@ -532,11 +518,12 @@ def forward( # pylint: disable=missing-function-docstring # Return input tensor if shape is not provided - ctx.shape = tensor.shape + if ctx is not None: + ctx.shape = tensor.shape if shape is None: return tensor - if shape != ctx.shape: + if list(shape) != list(tensor.shape): raise NotImplementedError("View not implemented.") return tensor @@ -568,7 +555,9 @@ def forward( # pylint: disable=missing-function-docstring # Return input tensor if shape is not provided - ctx.shape = tensor.shape + shape_arg = shape + if ctx is not None: + ctx.shape = tensor.shape if shape is None: return tensor @@ -579,12 +568,12 @@ def forward( shape = shape[0] if -1 in shape: shape = list(shape) - d_inferred = -math.prod(ctx.shape) // math.prod(shape) + d_inferred = -math.prod(tensor.shape) // math.prod(shape) for i, d in enumerate(shape): if d == -1: shape[i] = d_inferred break - if shape != ctx.shape: + if list(shape) != list(tensor.shape): raise NotImplementedError("Reshape not implemented yet.") return tensor From cc86afb8921f6090b159707f0f3765bb9ebf0220 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 12 Mar 2025 16:47:18 -0700 Subject: [PATCH 12/38] Reuse scale calculation between quantizer refs. Signed-off-by: Keith Wyss --- .../blockwise_quantizer_reference.py | 64 ++----------------- .../pytorch/references/quantize_scale_calc.py | 55 ++++++++++++++++ tests/pytorch/references/ref_per_tensor_cs.py | 59 +++-------------- 3 files changed, 68 insertions(+), 110 deletions(-) create mode 100644 tests/pytorch/references/quantize_scale_calc.py diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index d3460caea1..1fe9cfb28b 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -6,6 +6,9 @@ import math import torch from typing import Optional, Protocol, Tuple +from tests.pytorch.references.quantize_scale_calc import ( + scale_from_amax_tensor +) @dataclasses.dataclass() @@ -15,63 +18,6 @@ class QuantizeResult: data_t: Optional[torch.Tensor] scale_t: Optional[torch.Tensor] - -# FIXME(kwyss): Put this in a common location for per-tensor current -# scaling reference -def _scale_from_amax_tensor( - x_dtype: torch.dtype, - amax: torch.Tensor, - quant_dtype: torch.dtype, - *, - eps: float, - pow_2_scales: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Derives quantization and dequantization from amax and options. - - Reference implementation for scale calculation. - - Returns: - - scale: quantization scales - - scale_inv: dequantization scales - - amax: Amax tensor with updates made for extrema values. - """ - assert amax.dtype == torch.float, "amax must be a float tensor." - fp8_max = torch.finfo(quant_dtype).max - # Clamping amax to avoid division by small numbers - amax = torch.max(amax, torch.tensor(eps)) - - # Compute scale factor - scale = torch.div(fp8_max, amax) - # Note frexp doesn't give back inf for exponent with an inf input - # We take care of inf before pow_2_scales - scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) - if pow_2_scales: - # Calculate rounded down exponent - _, exp = torch.frexp(scale) - # Positive numbers are always returned as mant, exp with - # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with - # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because - # of the shift. Subnormal and zero cases need not be considered because - # the smallest possible result of fp8_max / amax is still normal. - exp = exp - 1 - # No subnormals and zero. - assert (exp > -127).all() - unity = torch.tensor([1.0], device=exp.device) - torch.ldexp(unity, exp, out=scale) - # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales - # Return 0.0 for 0.0 scale for consistency with non-pow2 scale - # calculation. - scale = torch.where(amax == float("inf"), 0.0, scale) - - # Handle overflow cases for amax zero causing NaN - scale = torch.where(amax == 0, 1.0, scale) - - # Compute scale_inv - scale_inv = torch.reciprocal(scale) - - return scale, scale_inv, amax - - @dataclasses.dataclass() class CuBLASScaleMunger: @@ -172,7 +118,7 @@ def _quantize_square_block_tiling( ).float() dtype_max = torch.finfo(quant_dtype).max - scale, scale_inv, _ = _scale_from_amax_tensor( + scale, scale_inv, _ = scale_from_amax_tensor( x_dtype=x.dtype, amax=amax_grid, quant_dtype=quant_dtype, @@ -209,7 +155,7 @@ def _quantize_vectorwise_reference( dtype_max = torch.finfo(quant_dtype).max x_tiled = x.reshape(M, K // tile_len, tile_len) amax_grid = torch.abs(x_tiled).amax(dim=-1).float() - scale, scale_inv, _ = _scale_from_amax_tensor( + scale, scale_inv, _ = scale_from_amax_tensor( x_dtype=x.dtype, amax=amax_grid, quant_dtype=quant_dtype, diff --git a/tests/pytorch/references/quantize_scale_calc.py b/tests/pytorch/references/quantize_scale_calc.py new file mode 100644 index 0000000000..eb2f424851 --- /dev/null +++ b/tests/pytorch/references/quantize_scale_calc.py @@ -0,0 +1,55 @@ +from typing import Tuple +import torch + +def scale_from_amax_tensor( + x_dtype: torch.dtype, + amax: torch.Tensor, + quant_dtype: torch.dtype, + *, + eps: float, + pow_2_scales: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Derives quantization and dequantization from amax and options. + + Reference implementation for scale calculation. + + Returns: + - scale: quantization scales + - scale_inv: dequantization scales + - amax: Amax tensor with updates made for extrema values. + """ + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + + # Compute scale factor + scale = torch.div(fp8_max, amax) + # Note frexp doesn't give back inf for exponent with an inf input + # We take care of inf before pow_2_scales + scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) + if pow_2_scales: + # Calculate rounded down exponent + _, exp = torch.frexp(scale) + # Positive numbers are always returned as mant, exp with + # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with + # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because + # of the shift. Subnormal and zero cases need not be considered because + # the smallest possible result of fp8_max / amax is still normal. + exp = exp - 1 + # No subnormals and zero. + assert (exp > -127).all() + unity = torch.tensor([1.0], device=exp.device) + torch.ldexp(unity, exp, out=scale) + # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales + # Return 0.0 for 0.0 scale for consistency with non-pow2 scale + # calculation. + scale = torch.where(amax == float("inf"), 0.0, scale) + + # Handle overflow cases for amax zero causing NaN + scale = torch.where(amax == 0, 1.0, scale) + + # Compute scale_inv + scale_inv = torch.reciprocal(scale) + + return scale, scale_inv, amax diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index dad0c42357..ad8a4674da 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -6,49 +6,9 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType_To_Torch - - -# Compute scale and scale_inv from amax -def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): - # Clamping amax to avoid division by small numbers - amax = torch.max(amax, torch.tensor(eps)) - - # Compute scale factor - scale = torch.div(fp8_max, amax) - # Note frexp doesn't give back inf for exponent with an inf input - # We take care of inf before pow_2_scales - # option1: set scale to fp32 max when scale is inf - scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale) - # option2: when scale is inf, set scale to 1 - scale = torch.where(scale == torch.inf, 1.0, scale) - if pow_2_scales: - # Calculate rounded down exponent - _, exp = torch.frexp(scale) - # Positive numbers are always returned as mant, exp with - # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with - # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because - # of the shift. Subnormal and zero cases need not be considered because - # the smallest possible result of fp8_max / amax is still normal. - exp = exp - 1 - # No subnormals and zero. - assert (exp > -127).all() - # TODO: If/when adding a URM option an option is to cap to 126 - # rather than allowing the full range of FP32 (2 - 2^23) x 2^127 - # addresses cases where adding a mantissa overflows into inf scales. - # Not necessary currently without additional scale smudging options. - unity = torch.tensor([1.0], device=exp.device) - torch.ldexp(unity, exp, out=scale) - # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales - # Return 0.0 for 0.0 scale for consistency with non-pow2 scale - # calculation. - scale = torch.where(amax == float("inf"), 0.0, scale) - - # Handle overflow cases for amax zero causing NaN - scale = torch.where(amax == 0, 1.0, scale) - # Compute scale_inv - scale_inv = torch.reciprocal(scale) - - return scale, scale_inv +from tests.pytorch.references.quantize_scale_calc import ( + scale_from_amax_tensor +) # compute amax and scale @@ -56,14 +16,11 @@ def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): x_fp32 = x.to(torch.float32) amax = torch.amax(torch.abs(x_fp32)).view(1) assert amax.dtype == torch.float, "amax must be a float tensor." - fp8_max = torch.finfo(quant_dtype).max - - scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) - # Clamping amax to avoid division by small numbers - amax = torch.max(amax, torch.tensor(eps)) - - return scale, scale_inv, amax - + return scale_from_amax_tensor(torch.float32, + amax, + quant_dtype, + eps=eps, + pow_2_scales=pow_2_scales) def _multi_dim_transpose(tensor): # Get the number of dimensions From a815b2a58b7cd53d148b4ca7f599eb8b05359b18 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 12 Mar 2025 16:19:38 -0700 Subject: [PATCH 13/38] Save memory by dropping reference to saved tensors. Issues previously observed are solved. Signed-off-by: Keith Wyss --- .../references/blockwise_quantizer_reference.py | 5 ++--- tests/pytorch/references/quantize_scale_calc.py | 1 + tests/pytorch/references/ref_per_tensor_cs.py | 14 +++++--------- .../_internal/float8_blockwise_tensor_base.py | 8 +++----- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index 1fe9cfb28b..1f85a1bb6b 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -6,9 +6,7 @@ import math import torch from typing import Optional, Protocol, Tuple -from tests.pytorch.references.quantize_scale_calc import ( - scale_from_amax_tensor -) +from tests.pytorch.references.quantize_scale_calc import scale_from_amax_tensor @dataclasses.dataclass() @@ -18,6 +16,7 @@ class QuantizeResult: data_t: Optional[torch.Tensor] scale_t: Optional[torch.Tensor] + @dataclasses.dataclass() class CuBLASScaleMunger: diff --git a/tests/pytorch/references/quantize_scale_calc.py b/tests/pytorch/references/quantize_scale_calc.py index eb2f424851..bd1cb43356 100644 --- a/tests/pytorch/references/quantize_scale_calc.py +++ b/tests/pytorch/references/quantize_scale_calc.py @@ -1,6 +1,7 @@ from typing import Tuple import torch + def scale_from_amax_tensor( x_dtype: torch.dtype, amax: torch.Tensor, diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index ad8a4674da..7c0a161b1c 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -6,21 +6,17 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType_To_Torch -from tests.pytorch.references.quantize_scale_calc import ( - scale_from_amax_tensor -) +from tests.pytorch.references.quantize_scale_calc import scale_from_amax_tensor # compute amax and scale def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): x_fp32 = x.to(torch.float32) amax = torch.amax(torch.abs(x_fp32)).view(1) - assert amax.dtype == torch.float, "amax must be a float tensor." - return scale_from_amax_tensor(torch.float32, - amax, - quant_dtype, - eps=eps, - pow_2_scales=pow_2_scales) + return scale_from_amax_tensor( + torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales + ) + def _multi_dim_transpose(tensor): # Get the number of dimensions diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 759d49b495..9135237854 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -71,12 +71,10 @@ def get_metadata(self) -> Dict[str, Any]: def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: - """Prepare the tensor base for saving for backward - - FIXME(kwyss): Set data tensors to None and consider saving/restoring scales. - test_numerics.py fails when tensors are cleared at the moment in C++ shape logic. - """ + """Prepare the tensor base for saving for backward""" tensors = [self._rowwise_data, self._columnwise_data] + self._rowwise_data = None + self._columnwise_data = None return tensors, self def restore_from_saved( From 86dbaa8ab10fdd2f5067ba51ea381144dbf5fc68 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 13 Mar 2025 16:15:08 -0700 Subject: [PATCH 14/38] Remove constexpr parameters from kernel. Code size is reduced with fewer constexpr params. Signed-off-by: Keith Wyss --- .../cast/benchmark_quantize_transpose.py | 94 ++++++++++++ .../test_float8_blockwise_scaling_exact.py | 8 +- .../quantize_transpose_square_blockwise.cu | 94 ++++++------ .../quantize_transpose_vector_blockwise.cu | 135 ++++++------------ 4 files changed, 196 insertions(+), 135 deletions(-) create mode 100644 benchmarks/experimental/cast/benchmark_quantize_transpose.py diff --git a/benchmarks/experimental/cast/benchmark_quantize_transpose.py b/benchmarks/experimental/cast/benchmark_quantize_transpose.py new file mode 100644 index 0000000000..e0474a50f7 --- /dev/null +++ b/benchmarks/experimental/cast/benchmark_quantize_transpose.py @@ -0,0 +1,94 @@ +import argparse +import logging +import os +import pathlib +import sys + +import pandas as pd +import torch +import torch.utils.benchmark as benchmark +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +import transformer_engine_torch as tex + + +def run_kernel( + shape, + is_1d: bool, + return_transpose: bool, + input_dtype=torch.bfloat16, + quant_dtype=tex.DType.kFloat8E4M3, +): + # Generate random input data + M, K = shape + src = torch.randn([M, K], dtype=input_dtype, device="cuda") + + quantizer = Float8BlockQuantizer( + fp8_dtype=quant_dtype, + rowwise=True, + columnwise=return_transpose, + block_scaling_dim=1 if is_1d else 2, + ) + dst = quantizer.make_empty(shape, dtype=input_dtype, device="cuda") + + kernel_func = tex.quantize + stmt = "kernel_func(src, quantizer, dst)" + globals_dict = { + "kernel_func": kernel_func, + "quantizer": quantizer, + "src": src, + "dst": dst, + } + measurement = benchmark.Timer( + stmt=stmt, + globals=globals_dict, + num_threads=1, + setup="", + ).adaptive_autorange(threshold=0.1, min_run_time=1.0, max_run_time=5.0) + logging.info(f"Measurement: {measurement}") + timing_us = measurement.median * 1e6 + return timing_us + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + type=str, + default="benchmark_output/", + help="output path for report", + ) + args = parser.parse_args() + + shapes = [ + (256, 1024), + # (256, 1020), + # 8B model shape + (4096, 3072), + (4096, 4096), + (4096, 5440), + # 15B model shape + (16384, 1024), + (16384, 3072), + (16384, 6144), + (16384, 12288), + (16384, 24576), + ] + + dim_1d_opts = [True, False] + return_transpose_opts = [True, False] + + data = [] + for dim_1d_opt in dim_1d_opts: + for return_transpose in return_transpose_opts: + for shape in shapes: + print(f"Running 1D={dim_1d_opt} with shape {shape}") + timing_us = run_kernel(shape, dim_1d_opt, return_transpose) + data.append([dim_1d_opt, return_transpose, shape, timing_us]) + + df = pd.DataFrame(data=data, columns=["is_1d_kernel", "return_transpose", "shape", "timing_us"]) + logging.info(df) + pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) + report_file = pathlib.Path(args.output_dir) / f"{pathlib.Path(__file__).stem}_report.csv" + df.to_csv(report_file, index=False) + print(df) + logging.info(f"Report saved to {report_file}") diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index e113f6ea8b..3c9c857fd3 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -8,7 +8,7 @@ import torch import transformer_engine as te import transformer_engine_torch as tex - +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -19,6 +19,10 @@ QuantizeResult, ) +# TODO replace with call to fp8.py when recipe added. +recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 +reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." + def initialize_for_many_scales( x_shape_2d: Tuple[int, int], tile_shape: Tuple[int, int], *, dtype: torch.dtype, device: str @@ -62,6 +66,7 @@ def initialize_for_many_scales( return result +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, N", [ @@ -188,6 +193,7 @@ def test_quantization_block_tiling_versus_reference( assert sx_t is None and sx_t_ref is None +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, N", [ diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index ca5633556d..5d2166a81d 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -65,7 +65,7 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP #define MIN(a, b) (a < b ? a : b) -template +template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, OType* const output_t, CType* const tile_scales_inv_c, @@ -73,7 +73,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, - const __grid_constant__ CUtensorMap tensor_map_output_t) { + const __grid_constant__ CUtensorMap tensor_map_output_t, + bool pow_2_scaling) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; @@ -152,7 +153,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) __syncthreads(); block_tile_amax = block_tile_amax_shared[0]; - block_tile_scale = ComputeScale(block_tile_amax, epsilon); + if (pow_2_scaling) { + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + } else { + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + } if (threadIdx.x == 0) { static_assert(std::is_same::value); @@ -244,12 +249,13 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } } -template +template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( const IType* const input, OType* const output_c, OType* const output_t, CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, - const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon) { + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, + bool pow_2_scaling) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; @@ -372,7 +378,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose __syncthreads(); block_tile_amax = block_tile_amax_shared[0]; - block_tile_scale = ComputeScale(block_tile_amax, epsilon); + if (pow_2_scaling) { + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + } else { + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + } if (threadIdx.x == 0) { static_assert(std::is_same::value); @@ -517,47 +527,43 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output.dtype, OutputType, - dim3 grid(num_blocks_x, num_blocks_y, 1); - const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; TRANSFORMER_ENGINE_SWITCH_CONDITION( return_transpose, kReturnTranspose, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - pow_2_scale, kPow2Scale, - - if (full_tile) { - CUtensorMap tensor_map_output_trans; - if constexpr (kReturnTranspose) { - tensor_map_output_trans = - get_tensor_map(output_t, num_rows, row_length); - } - block_scaled_cast_transpose_kernel - <<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, - epsilon, tensor_map_output_trans); - } else { - block_scaled_cast_transpose_kernel_notaligned - <<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, - epsilon); - } // full-tile - - ) // kPow2Scale - ) // kReturnTranspose - ) // OutputType - ) // InputType + dim3 grid(num_blocks_x, num_blocks_y, 1); + const bool full_tile = + row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; + + if (full_tile) { + CUtensorMap tensor_map_output_trans; + if (return_transpose) { + tensor_map_output_trans = + get_tensor_map(output_t, num_rows, row_length); + } + block_scaled_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + tensor_map_output_trans, pow_2_scale); + } else { + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + pow_2_scale); + } // full-tile + ) // return_transpose + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index de53071eda..2304084632 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -137,8 +137,7 @@ constexpr int kNumThreadsStore = kTileDim / kNVecOut; static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); -template +template __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, OType* const output_t, CType* const tile_scales_inv_c, @@ -146,7 +145,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, - const size_t scale_t_stride_y, const float epsilon) { + const size_t scale_t_stride_y, const float epsilon, + bool return_transpose, bool pow_2_scaling) { using SMemVec = Vec; using OVec = Vec; union IVec { @@ -252,8 +252,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) amax = fmaxf(amax, other_amax); } amax = __shfl_sync(mask, amax, src_lane); + CType scale; // Step 2.4: Compute scale - CType scale = ComputeScale(amax, epsilon); + if (pow_2_scaling) { + scale = ComputeScale(amax, epsilon); + } else { + scale = ComputeScale(amax, epsilon); + } // Step 2.5: Write scale_inv bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -263,15 +268,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) CType scale_inv = 1.0 / scale; size_t row_idx = static_cast(blockIdx.y) * kTileDim + r_s; size_t col_idx = static_cast(blockIdx.x); - if constexpr (kPermuteScale) { - size_t p_row = row_idx / kTileDim; - size_t p_col = col_idx; - size_t p_dep = row_idx % kTileDim; - size_t p_2d_stride = kTileDim * scale_stride_y; - tile_scales_inv_c[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; - } else { - tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; - } + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; } // Step 2.6: Quantize OVec output_vec; @@ -301,7 +298,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } // Step 3: Transpose, cast and store to output_t - if constexpr (kReturnTranspose) { + if (return_transpose) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); @@ -349,7 +346,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } amax = __shfl_sync(mask, amax, src_lane); // Step 3.4: Compute scale - CType scale = ComputeScale(amax, epsilon); + CType scale; + if (pow_2_scaling) { + scale = ComputeScale(amax, epsilon); + } else { + scale = ComputeScale(amax, epsilon); + } // Step 3.5: Write scale_inv_t bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -359,15 +361,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) CType scale_inv = 1.0 / scale; size_t row_idx = static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx; size_t col_idx = static_cast(blockIdx.y); - if constexpr (kPermuteScale) { - size_t p_row = row_idx / kTileDim; - size_t p_col = col_idx; - size_t p_dep = row_idx % kTileDim; - size_t p_2d_stride = kTileDim * scale_t_stride_y; - tile_scales_inv_t[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; - } else { - tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; - } + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; } // Step 3.6: Quantize OVec output_vec; @@ -422,32 +416,16 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor } // Options for scale layout of cuBLAS GEMM kernel. - constexpr bool kPermuteScale = false; - bool permute_scale = false; - bool transpose_scales = true; NVTE_CHECK(input.shape.size() == output.shape.size(), "Input and output must have the same shape."); - NVTE_CHECK((!transpose_scales || !permute_scale), - "Permute scale and transpose scales are mutually exclusive flags."); size_t scale_stride_x = 0; size_t scale_stride_y = 0; - if (permute_scale) { - NVTE_CHECK(scale_inv.shape.size() == 3, "scale_inv must have 3 dimensions."); - size_t scale_k = scale_inv.shape[1]; - NVTE_CHECK(scale_inv.shape[2] == kTileDim, "Scale inner dimension must be kTileDim."); - scale_stride_x = 1; - scale_stride_y = scale_k; - } else { - NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2 when not permuting scale."); - size_t scale_k = scale_inv.shape[1]; - scale_stride_x = 1; - scale_stride_y = scale_k; - if (transpose_scales) { - std::swap(scale_stride_x, scale_stride_y); - } - } + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = scale_k; + scale_stride_y = 1; size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; @@ -464,20 +442,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); - if (permute_scale) { - NVTE_CHECK(scale_inv_t.shape.size() == 3, "Scale_t dimension must be 3."); - scale_t_stride_x = 1; - scale_t_stride_y = scale_inv_t.shape[1]; - NVTE_CHECK(scale_inv_t.shape[2] == kTileDim, "Scale_t inner dimension must be kTileDim."); - } else { - NVTE_CHECK(scale_inv_t.shape.size() == 2, - "Scale_t dimension must be 2 when not permuting scale."); - scale_t_stride_x = 1; - scale_t_stride_y = scale_inv_t.shape[1]; - if (transpose_scales) { - std::swap(scale_t_stride_x, scale_t_stride_y); - } - } + NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2."); + scale_t_stride_x = scale_inv_t.shape[1]; + scale_t_stride_y = 1; } const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); @@ -494,38 +461,26 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; TRANSFORMER_ENGINE_SWITCH_CONDITION( - return_transpose, kReturnTranspose, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - pow2_scale, kPow2Scale, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - full_tile, kAligned, - - size_t smem_bytes = kSMemSize * sizeof(InputType); - // shared memory must be requested up - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - &block_scaled_1d_cast_transpose_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); - } block_scaled_1d_cast_transpose_kernel - <<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, - epsilon);) // kAligned - ) // kPow2Scale - ) // kReturnTranspose - ) // OutputType - ) // InputType + full_tile, kAligned, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + // shared memory must be requested up + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + &block_scaled_1d_cast_transpose_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); + } block_scaled_1d_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose, + pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } From 8ad710705cb078fdfe9130c6763a2abd8853f44f Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 17 Mar 2025 10:20:31 -0700 Subject: [PATCH 15/38] Merge conflict from rebase. Signed-off-by: Keith Wyss --- tests/cpp/test_common.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4ecd51dca3..006b4bfd09 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -243,10 +243,6 @@ Tensor::Tensor(const std::string& name, NVTEShape normalized_shape = convertShape(normalized_shape_v); NVTEShape columnwise_shape{nullptr, 0}; - size_t block_scaling_dim = 0; - if (q_opts != nullptr) { - block_scaling_dim = q_opts->block_scaling_dim; - } std::vector columnwise_shape_vec; if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { // Transpose when tensor scaling From 2d6a3795da87ace45729cf1b75ed1142ba236bee Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 19 Mar 2025 13:04:24 -0700 Subject: [PATCH 16/38] Add shape implementations for block scaling. nvte_shape was added upstream. Logic added for block scaled fp8. Signed-off-by: Keith Wyss --- transformer_engine/common/common.h | 22 +++++++++++++ .../transformer_engine/transformer_engine.h | 17 ++++++++++ .../common/transformer_engine.cpp | 32 +++++++++++++++++++ .../csrc/extensions/type_converters.cpp | 26 +++++++++++++-- 4 files changed, 94 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 47a84ad068..7dee033c8c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -185,6 +185,28 @@ struct Tensor { return data.shape; } break; + case NVTE_BLOCK_SCALING_1D: + case NVTE_BLOCK_SCALING_2D: { + if (!has_data() && has_columnwise_data()) { + std::vector shape; + size_t ndim = columnwise_data.shape.size(); + shape.reserve(ndim); + for (size_t i = 0; i + 1 < ndim; ++i) { + shape.push_back(columnwise_data.shape[i + 1]); + } + if (ndim > 0) { + shape.push_back(columnwise_data.shape[0]); + } + return shape; + } else { + // NOTE: We may have removed the data pointer from + // data by setting usage. In that case, we return + // the non-null shape. It is our best guess at the most + // recent shape. + return data.shape; + } + break; + } default: NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); return {}; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 0e2b6cfcff..23a6d8226a 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -110,6 +110,19 @@ typedef void *NVTETensor; */ NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); +/*! \brief Create a new TE tensor. + * + * Create a new TE tensor. Before use its parameters need to be set. + * TE tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] scaling_mode Scaling mode of the tensor. + * \param[in] initial_shape Shape to initialize tensor with. + * + * \return A new TE tensor. + */ +NVTETensor nvte_create_tensor_with_shape(NVTEScalingMode scaling_mode, NVTEShape initial_shape); + /*! \brief Destroy a TE tensor. * * Since the TE tensor does not own memory, the underlying @@ -469,6 +482,10 @@ class TensorWrapper { explicit TensorWrapper(const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) : tensor_(nvte_create_tensor(scaling_mode)) {} + TensorWrapper(const NVTEScalingMode scaling_mode, const std::vector &rowwise_shape) + : tensor_(nvte_create_tensor_with_shape( + scaling_mode, NVTEShape{rowwise_shape.data(), rowwise_shape.size()})) {} + /*! \brief TensorWrapper destructor. */ ~TensorWrapper() { nvte_destroy_tensor(tensor_); } diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 2fa9fc3aba..5b4fbeb258 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -199,6 +199,16 @@ NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { return ret; } +NVTETensor nvte_create_tensor_with_shape(NVTEScalingMode scaling_mode, NVTEShape initial_shape) { + transformer_engine::Tensor *ret = new transformer_engine::Tensor; + ret->scaling_mode = scaling_mode; + ret->data.shape.reserve(initial_shape.ndim); + for (size_t i = 0; i < initial_shape.ndim; ++i) { + ret->data.shape.push_back(initial_shape.data[i]); + } + return ret; +} + void nvte_destroy_tensor(NVTETensor tensor) { if (tensor == nullptr) return; auto *t = reinterpret_cast(tensor); @@ -252,6 +262,28 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { } break; } + case NVTE_BLOCK_SCALING_1D: + case NVTE_BLOCK_SCALING_2D: { + if (!t.has_data() && t.has_columnwise_data()) { + std::vector shape; + ret.ndim = t.columnwise_data.shape.size(); + shape.reserve(ret.ndim); + for (int i = 0; i + 1 < static_cast(ret.ndim); ++i) { + shape.push_back(t.columnwise_data.shape[i + 1]); + } + if (ret.ndim > 0) { + shape.push_back(t.columnwise_data.shape[0]); + } + NVTE_CHECK(t.data.shape == shape, + "Must return shape allocated on tensor. " + "data shape expected to match derivation from columnwise."); + ret.data = t.data.shape.data(); + } else { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + } + break; + } default: NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", transformer_engine::to_string(t.scaling_mode), "\""); diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index 440e819ae7..6a32d4cec0 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -87,17 +87,37 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { const DType dtype = tensor.attr("_fp8_dtype").cast(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); - auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + std::vector initial_rowwise_shape; + if (rowwise_usage) { + initial_rowwise_shape = getTensorShape(tensor.attr("_rowwise_data").cast()); + } else if (columnwise_usage) { + std::vector columnwise_shape = + getTensorShape(tensor.attr("_columnwise_data").cast()); + + // Even though we don't have rowwise data, we want to store the + // rowwise shape so that nvte_tensor_shape can return an allocated + // vector. + initial_rowwise_shape.reserve(columnwise_shape.size()); + for (size_t i = 0; i + 1 < columnwise_shape.size(); ++i) { + initial_rowwise_shape.push_back(columnwise_shape[i + 1]); + } + if (columnwise_shape.size() > 0) { + initial_rowwise_shape.push_back(columnwise_shape[0]); + } + } + + auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, + initial_rowwise_shape); + if (rowwise_usage) { const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto &shape = getTensorShape(data_rowwise); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, initial_rowwise_shape); const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); From 2306611cae4abdf07453d0bec662b3bfc517fef4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 14:15:33 -0700 Subject: [PATCH 17/38] Move benchmark to te_playground Signed-off-by: Keith Wyss --- .../cast/benchmark_quantize_transpose.py | 94 ------------------- 1 file changed, 94 deletions(-) delete mode 100644 benchmarks/experimental/cast/benchmark_quantize_transpose.py diff --git a/benchmarks/experimental/cast/benchmark_quantize_transpose.py b/benchmarks/experimental/cast/benchmark_quantize_transpose.py deleted file mode 100644 index e0474a50f7..0000000000 --- a/benchmarks/experimental/cast/benchmark_quantize_transpose.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse -import logging -import os -import pathlib -import sys - -import pandas as pd -import torch -import torch.utils.benchmark as benchmark -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -import transformer_engine_torch as tex - - -def run_kernel( - shape, - is_1d: bool, - return_transpose: bool, - input_dtype=torch.bfloat16, - quant_dtype=tex.DType.kFloat8E4M3, -): - # Generate random input data - M, K = shape - src = torch.randn([M, K], dtype=input_dtype, device="cuda") - - quantizer = Float8BlockQuantizer( - fp8_dtype=quant_dtype, - rowwise=True, - columnwise=return_transpose, - block_scaling_dim=1 if is_1d else 2, - ) - dst = quantizer.make_empty(shape, dtype=input_dtype, device="cuda") - - kernel_func = tex.quantize - stmt = "kernel_func(src, quantizer, dst)" - globals_dict = { - "kernel_func": kernel_func, - "quantizer": quantizer, - "src": src, - "dst": dst, - } - measurement = benchmark.Timer( - stmt=stmt, - globals=globals_dict, - num_threads=1, - setup="", - ).adaptive_autorange(threshold=0.1, min_run_time=1.0, max_run_time=5.0) - logging.info(f"Measurement: {measurement}") - timing_us = measurement.median * 1e6 - return timing_us - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--output_dir", - type=str, - default="benchmark_output/", - help="output path for report", - ) - args = parser.parse_args() - - shapes = [ - (256, 1024), - # (256, 1020), - # 8B model shape - (4096, 3072), - (4096, 4096), - (4096, 5440), - # 15B model shape - (16384, 1024), - (16384, 3072), - (16384, 6144), - (16384, 12288), - (16384, 24576), - ] - - dim_1d_opts = [True, False] - return_transpose_opts = [True, False] - - data = [] - for dim_1d_opt in dim_1d_opts: - for return_transpose in return_transpose_opts: - for shape in shapes: - print(f"Running 1D={dim_1d_opt} with shape {shape}") - timing_us = run_kernel(shape, dim_1d_opt, return_transpose) - data.append([dim_1d_opt, return_transpose, shape, timing_us]) - - df = pd.DataFrame(data=data, columns=["is_1d_kernel", "return_transpose", "shape", "timing_us"]) - logging.info(df) - pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) - report_file = pathlib.Path(args.output_dir) / f"{pathlib.Path(__file__).stem}_report.csv" - df.to_csv(report_file, index=False) - print(df) - logging.info(f"Report saved to {report_file}") From fff18183fb3892a935d8a64c692fa51fe9e97ef6 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 14:36:37 -0700 Subject: [PATCH 18/38] Remove amax_epsilon and pow_2_scales from tensor. Hardcodes the default values. Signed-off-by: Keith Wyss --- .../cpp/operator/test_cast_float8blockwise.cu | 6 +- .../test_float8_blockwise_scaling_exact.py | 8 +-- transformer_engine/common/common.h | 27 +------ .../transformer_engine/transformer_engine.h | 70 ------------------- .../common/transformer_engine.cpp | 30 -------- .../common/util/cast_kernels.cuh | 24 +++---- .../pytorch/csrc/extensions/quantizer.cpp | 14 ++-- .../pytorch/tensor/float8_blockwise_tensor.py | 2 +- 8 files changed, 28 insertions(+), 153 deletions(-) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index ce3175aafd..cc27f72769 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -429,8 +429,6 @@ std::vector Activation_types = { std::vector amax_epsilons = { 0.0f, - // Set large epsilon to get observable behavior. - 0.1f, }; } // namespace @@ -601,7 +599,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); @@ -625,7 +623,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), - ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 3c9c857fd3..a0e11a7af2 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -88,11 +88,11 @@ def initialize_for_many_scales( ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0, 1e-12], ids=["eps_0", "eps_1e-12"]) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) @pytest.mark.parametrize( "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] ) -@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) @pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, @@ -203,8 +203,8 @@ def test_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("eps", [0, math.pow(2, -125)], ids=["eps_0", "eps_small"]) -@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) @pytest.mark.parametrize("tile_size", [(128, 128)]) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) def test_quantization_block_tiling_extrema_versus_reference( diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 7dee033c8c..2d4629f0b1 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -101,9 +101,6 @@ struct Tensor { NVTEScalingMode scaling_mode; - float amax_epsilon; - bool force_pow_2_scales; - Tensor() : data(), columnwise_data(), @@ -111,9 +108,7 @@ struct Tensor { scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), - scaling_mode(NVTE_DELAYED_TENSOR_SCALING), - amax_epsilon(0.0), - force_pow_2_scales(false) {} + scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} int numel() const { size_t acc = 1; @@ -130,26 +125,6 @@ struct Tensor { return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0; } - bool supports_force_pow_2_scales_qopt() const noexcept { - switch (scaling_mode) { - case NVTE_BLOCK_SCALING_2D: - case NVTE_BLOCK_SCALING_1D: - return true; - default: - return false; - } - } - - bool supports_amax_epsilon_qopt() const noexcept { - switch (scaling_mode) { - case NVTE_BLOCK_SCALING_2D: - case NVTE_BLOCK_SCALING_1D: - return true; - default: - return false; - } - } - DType dtype() const { if (has_data()) return data.dtype; if (has_columnwise_data()) return columnwise_data.dtype; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 23a6d8226a..85a387c450 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -254,52 +254,6 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, const NVTEBasicTensor *param); -/*! \brief Set a quantization option for whether to force power of 2 scales. - * - * \param[in/out] tensor Tensor. - * \param[in] zero_if_false Whether to force power of 2 scales. - * - * \return zero if the tensor supports this option and it was set. non-zero if - * call had no effect. - */ -int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false); - -/*! \brief Set a quantization option for epsilon to set floor of amax. - * - * \param[in/out] tensor Tensor. - * \param[in] amax_epsilon Epsilon to use for amax calculation. - * - * \return zero if the tensor supports this option and it was set. non-zero if - * call had no effect. - */ -int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon); - -/*! \brief Get a quantization option for whether to force power of 2 scales. - * - * \param[in] tensor Tensor. - * - * \return zero if the tensor will not force power of 2 scales or if the - * setting is irrelevant. non-zero if the flag is configured. - */ -int nvte_get_qopt_force_pow_2_scales(NVTETensor tensor); - -/*! \brief Get a quantization option for amax epsilon. - * - * \param[in] tensor Tensor. - * - * \return amax_epsilon value or zero if not applicable. - */ -float nvte_get_qopt_amax_epsilon(const NVTETensor tensor); - -/*! \brief Get the number of dimensions in the quantization blocks. - * - * \param[in] tensor Tensor. - * - * \return zero if the quantization does not support the block_scaling_dim - * option or the block_scaling_dim configured. - */ -int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor); - /*! \brief Get a value of the parameter of the tensor. * * \param[in] tensor Tensor. @@ -729,18 +683,6 @@ class TensorWrapper { void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } - int set_qopt_force_pow_2_scales(bool flag) { - return nvte_set_qopt_force_pow_2_scales(tensor_, flag ? 1 : 0); - } - - int set_qopt_amax_epsilon(float eps) { return nvte_set_qopt_amax_epsilon(tensor_, eps); } - - bool get_qopt_force_pow_2_scales() const { - return nvte_get_qopt_force_pow_2_scales(tensor_) != 0; - } - - float get_qopt_amax_epsilon() const { return nvte_get_qopt_amax_epsilon(tensor_); } - static constexpr size_t defaultData = 1; static constexpr NVTEShape defaultShape = {&defaultData, 1}; @@ -788,18 +730,6 @@ class QuantizationConfigWrapper { */ operator NVTEQuantizationConfig() const noexcept { return config_; } - /*! \brief Set whether to force power of 2 scales */ - void set_force_pow_2_scales(bool force_pow_2_scales) { - nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales, - &force_pow_2_scales, sizeof(bool)); - } - - /*! \brief Set small value to add to amax */ - void set_amax_epsilon(float amax_epsilon) { - nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEpsilon, - &amax_epsilon, sizeof(float)); - } - private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 5b4fbeb258..c8ea95c7b6 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -534,33 +534,3 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { delete reinterpret_cast(config); } } - -int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false) { - auto &t = *reinterpret_cast(tensor); - if (t.supports_force_pow_2_scales_qopt()) { - t.force_pow_2_scales = zero_if_false != 0; - return 0; - } else { - return 1; - } -} - -int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon) { - auto &t = *reinterpret_cast(tensor); - if (t.supports_amax_epsilon_qopt()) { - t.amax_epsilon = amax_epsilon; - return 0; - } else { - return 1; - } -} - -int nvte_get_qopt_force_pow_2_scales(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); - return t.force_pow_2_scales ? 1 : 0; -} - -float nvte_get_qopt_amax_epsilon(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); - return t.amax_epsilon; -} diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 458f2ff217..412a6f6ef0 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1266,24 +1266,24 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - quantize_transpose_square_blockwise(input_tensor->data, output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, output_tensor->data, - output_tensor->columnwise_data, - /*epsilon=*/output_tensor->amax_epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), - output_tensor->force_pow_2_scales, stream); + constexpr bool force_pow_2_scales = true; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/0.0, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } case NVTE_BLOCK_SCALING_1D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, output_tensor->data, - output_tensor->columnwise_data, - /*epsilon=*/output_tensor->amax_epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), - output_tensor->force_pow_2_scales, stream); + constexpr bool force_pow_2_scales = true; + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/0.0, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); break; } default: diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 5826427574..364cd26ef9 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -256,9 +256,15 @@ std::pair Float8CurrentScalingQuantizer::create_tenso Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); - this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); - this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast(), + "Pending additional parameters to the nvte_quantize API, " + "float8 block quantization requires pow2 scales"); + NVTE_CHECK(quantizer.attr("amax_epsilon").cast() == 0.0, + "Pending additional parameters to the nvte_quantize API, " + "float8 block quantization requires amax_epsilon==0"); + NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, + "Unsupported block scaling dim."); } void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -274,10 +280,6 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const rowwise_data.shape); tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); - - // Set options on TensorWrapper from quantization. - tensor->set_qopt_force_pow_2_scales(force_pow_2_scales); - tensor->set_qopt_amax_epsilon(amax_epsilon); } std::pair Float8BlockQuantizer::create_tensor( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 56c8a934ed..701fad8c33 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -40,7 +40,7 @@ def __init__( rowwise: bool, columnwise: bool, amax_epsilon: float = 0.0, - force_pow_2_scales: bool = False, + force_pow_2_scales: bool = True, block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) From 4de7aacd1c1ce6282e204458151d06a73789a1da Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 15:04:38 -0700 Subject: [PATCH 19/38] Lint changes. Signed-off-by: Keith Wyss --- .../pytorch/tensor/float8_blockwise_tensor.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 701fad8c33..9e644ccb8a 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -359,13 +359,13 @@ def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # because of __torch_function__ in base class # and torch._C._disabled_torch_function_impl return _ViewFunc.forward(None, self, shape) - return super.view(self, *shape) + return super().view(self, *shape) def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring if not self.requires_grad: return _ReshapeFunc.forward(None, self, shape) - return super.reshape(self, *shape) + return super().reshape(self, *shape) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -477,10 +477,11 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst.requires_grad_(requires_grad=src.requires_grad) # Just copy FP8 data if other tensor is Float8BlockwiseQTensor - if ( - isinstance(tensor, Float8BlockwiseQTensor) + compatible_layout = (isinstance(tensor, Float8BlockwiseQTensor) and self.size() == tensor.size() - and self.stride() == tensor.stride() + and self.stride() == tensor.stride()) + if ( + compatible_layout and self.storage_offset() == tensor.storage_offset() and self.dtype == tensor.dtype and self.layout == tensor.layout @@ -488,7 +489,7 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): ): _set_from_tensor(self, tensor) return - elif isinstance(tensor, Float8BlockwiseQTensor): + if isinstance(tensor, Float8BlockwiseQTensor): assert tensor._quantizer is not None, "Can't quantize without a quantizer" quantizer = tensor._quantizer else: @@ -555,7 +556,6 @@ def forward( # pylint: disable=missing-function-docstring # Return input tensor if shape is not provided - shape_arg = shape if ctx is not None: ctx.shape = tensor.shape if shape is None: From e6316e95736c87cd4a1dd93f1f2bb1c5d550466a Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 15:16:04 -0700 Subject: [PATCH 20/38] Fixup MR changes that broke. Signed-off-by: Keith Wyss --- .../include/transformer_engine/transformer_engine.h | 12 ++++++++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 6 ++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 85a387c450..18244b8ece 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -730,6 +730,18 @@ class QuantizationConfigWrapper { */ operator NVTEQuantizationConfig() const noexcept { return config_; } + /*! \brief Set whether to force power of 2 scales */ + void set_force_pow_2_scales(bool force_pow_2_scales) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales, + &force_pow_2_scales, sizeof(bool)); + } + + /*! \brief Set small value to add to amax */ + void set_amax_epsilon(float amax_epsilon) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEpsilon, + &amax_epsilon, sizeof(float)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 9e644ccb8a..2de587fbff 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -477,9 +477,11 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst.requires_grad_(requires_grad=src.requires_grad) # Just copy FP8 data if other tensor is Float8BlockwiseQTensor - compatible_layout = (isinstance(tensor, Float8BlockwiseQTensor) + compatible_layout = ( + isinstance(tensor, Float8BlockwiseQTensor) and self.size() == tensor.size() - and self.stride() == tensor.stride()) + and self.stride() == tensor.stride() + ) if ( compatible_layout and self.storage_offset() == tensor.storage_offset() From fd951d8ba2be69bd20cf721e624bbd618ca68ec2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 16:20:12 -0700 Subject: [PATCH 21/38] Safer ifdef in kernel. Signed-off-by: Keith Wyss --- .../common/transpose/quantize_transpose_square_blockwise.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 5d2166a81d..8ef6267018 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -16,7 +16,7 @@ #include "common/utils.cuh" #include "compute_scale.cuh" -#if (!defined(__CUDA_MINIMUM_ARCH__)) || \ +#if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \ (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) #define TMA_HW_SUPPORTED #endif From cf0021afbb29cbdec2c0b31bf48b98a090af52fc Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 17:35:45 -0700 Subject: [PATCH 22/38] Documentation prose. Signed-off-by: Keith Wyss --- .../common/include/transformer_engine/cast.h | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 3c0d24df78..b46ca51256 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -18,21 +18,26 @@ extern "C" { #endif /* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer) - * The implementation is per the microscaling format MXFP8 defined by the OCP specification: + * Supported formats are: + * + * 1) MXFP8 scaling (for 10.0 or newer) + * + * The MXFP8 implementation is per the microscaling format MXFP8 defined by the OCP specification: * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf * - * Supported modes of scaling (live scaling): - * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: + * + * Supported modes of MXFP8 scaling (live scaling): + * a) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: * - the scaled output tensor * - the corresponding scaling factors * The scaling factors are computed for blocks of the shape [1,32] * (i.e., each scaling factor spans 32 contiguous elements along rows). * - * 2) Columwise scaling (along the dim=1) computes one set of the output data. + * b) Columwise scaling (along the dim=1) computes one set of the output data. * The scaling factors are computed for blocks of the shape [32,1] * (i.e., each scaling factor spans 32 contiguous elements along columns). * - * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) + * c) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) * computes two sets of the output data: both 1) and 2). * * The shape of the MX block must be specified in the 'output' argument, @@ -40,6 +45,37 @@ extern "C" { * * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter * of the output tensor should be set to 0. + * + * Also supported are + * + * 2) per-tensor scaling modes that quantize the entire tensor + * using a single scaling factor. The absolute maximum value of the tensor should + * be precalculated either online (current scaling) or based on a tensor history + * (delayed scaling). The calls to nvte_quantize scale based on that data value. + * + * + * 3) FP8 block scaling formats NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D + * for compute capability of at least 9.0. These modes quantize the tensor by blocks + * of size 1x128 (with columnwise mode of 128x1) and 128x128 respectively. + * + * The supported modes are: + * a) Rowwise scaling (along the dim=0) yields output data + * - the scaled output tensor in fp8 coefficients with identical shape to the + * input tensor. + * - Scale factors which are computed for either 1D 1x128 or 2D 128x128 blocks. + * b) Columnwise scaling (along the dim=1) yields output data + * - the scaled output tensor in fp8 coefficients with a shape equivalent to + * the transpose of the input tensor. + * - Scale factors which are calculated for either 1D 128x1 or 2D 128x128 blocks + * of the input tensor. + * c) Both: In which all four tensors of the above are calculated. + * + * This quantization mode includes both the calculation of the scaling factors + * per-tile and quantization of the row and/or columnwise tiles. No precalculated + * absolute max is required. The scaling factors are also rounded to powers of 2, + * such that even if they are stored in fp32 on compute capability 9.0, they are + * numerically compatible with e8m0 scales such that the quantization is portable + * to hardware supporting MXFP8 without numerical disruption. */ /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8. From 32cc5b4cbb26c15a6489d7f51c59e639a4ff413f Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 18:52:49 -0700 Subject: [PATCH 23/38] Reuse compute_scale function from Current Scaling. Signed-off-by: Keith Wyss --- .../common/recipe/current_scaling.cu | 3 +- .../common/recipe/recipe_common.cuh | 17 +++---- .../common/transpose/compute_scale.cuh | 49 ++----------------- 3 files changed, 12 insertions(+), 57 deletions(-) diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index e53ab18360..197863569e 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -152,7 +152,8 @@ namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, const float epsilon) { - *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon); + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon, + std::numeric_limits::max()); } } // namespace diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh index c789a9b497..9554211cb5 100644 --- a/transformer_engine/common/recipe/recipe_common.cuh +++ b/transformer_engine/common/recipe/recipe_common.cuh @@ -7,19 +7,19 @@ #ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ #define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ -#include - namespace transformer_engine { __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8, - bool force_pow_2_scales, float epsilon) { + bool force_pow_2_scales, float epsilon, + float value_for_inf) { + // NOTE: NAN amax evaluates false for <, handled further down. if (amax < epsilon) { amax = epsilon; } float scale = 1.f; - if (isinf(amax) || amax == 0.f) { + if (isinf(amax) || amax == 0.f || isnan(amax)) { return scale; } @@ -32,18 +32,13 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f // the scale is not representable in FP32. if (isinf(scale)) { // use fp32 max to represent the scale - scale = std::numeric_limits::max(); - } - - if (isnan(scale)) { - scale = 1.f; + scale = inf_value; } - if (force_pow_2_scales) { uint32_t scale_bits = *reinterpret_cast(&scale); scale_bits &= 0xFF800000; // If the exponent was zero, we have a logic error. - __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0 || scale == 0.0); __builtin_assume(scale_bits != 0x80000000); scale = *reinterpret_cast(&scale_bits); } diff --git a/transformer_engine/common/transpose/compute_scale.cuh b/transformer_engine/common/transpose/compute_scale.cuh index 0f17829fb2..8013d452a6 100644 --- a/transformer_engine/common/transpose/compute_scale.cuh +++ b/transformer_engine/common/transpose/compute_scale.cuh @@ -14,6 +14,8 @@ #include #include +#include "../recipe/recipe_common.cuh" + namespace transformer_engine { // Type trait for extreme values of fp8 types. @@ -88,51 +90,8 @@ struct HighPrecisionFloatScaleLimitsTrait { template __device__ __forceinline__ float ComputeScale(const float amax, const float eps) { constexpr float fp8_max = F8LimitsTrait::max; - - // Clamping amax to avoid division by small numbers - float amax_mod = fmaxf(amax, eps); - - // Handle overflow cases for non-clamped amax (eps is 0 or very small) - if (amax_mod == 0.f) { - // If amax is 0, return 1 - return 1.f; - } - // Compute scale factor - float scale = fp8_max / amax_mod; - - if (isinf(scale)) { - // If scale is infinity, return max value of IType - return HighPrecisionFloatScaleLimitsTrait::max; - } - if (scale == 0.0) { - // Case that amax is "inf". The frexp, ldexp logic changes 0.0 scales. - // Return 0.0 for 0.0 scale here is consistent with non-Power2Scaling model. - // quantization will remove signal from the tensor, - // this is bad for the model, but define pow2Scale behavior - // as returning 0.0 scale. amax calculation can - // improve the situation to avoid this by taking largest finite. - return scale; - } - if constexpr (Power2Scaling) { - // NOTE: using bit fiddling rather than pow2, exp to - // be exact. - // - // inf scales already early returned, as did nan scales. - // The cases to consider here are normals, zero, and subnormals. - // zero is not possible with current math as - // 448.0 / float_max == 1.31655e-36, which is the smallest - // possible scale given current dtypes. It is still in the normal - // fp32 range with an exponent of -120, so subnormals are also - // not possible. To handle normals, we can simply mask off the - // mantissa. - uint32_t scale_bits = *reinterpret_cast(&scale); - scale_bits &= 0xFF800000; - // If the exponent was zero, we have a logic error. - __builtin_assume(scale_bits != 0); - __builtin_assume(scale_bits != 0x80000000); - scale = *reinterpret_cast(&scale_bits); - } - return scale; + constexpr float value_for_inf = HighPrecisionFloatScaleLimitsTrait::max; + return compute_scale_from_amax(amax, fp8_max, Power2Scaling, eps, value_for_inf); } } // namespace transformer_engine From d23ae3b885308e9bbe00c4283d6bd522e65a7d5c Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 10:32:39 -0700 Subject: [PATCH 24/38] Bugfix on inf_value scale refactor. Signed-off-by: Keith Wyss --- transformer_engine/common/recipe/recipe_common.cuh | 2 +- .../extensions/multi_tensor/multi_tensor_compute_scale.cu | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh index 9554211cb5..f7c6e12fbf 100644 --- a/transformer_engine/common/recipe/recipe_common.cuh +++ b/transformer_engine/common/recipe/recipe_common.cuh @@ -32,7 +32,7 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f // the scale is not representable in FP32. if (isinf(scale)) { // use fp32 max to represent the scale - scale = inf_value; + scale = value_for_inf; } if (force_pow_2_scales) { uint32_t scale_bits = *reinterpret_cast(&scale); diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu index d262767958..0770e63015 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu @@ -12,6 +12,8 @@ // #include #include + +#include // Stringstream is a big hammer, but I want to rely on operator<< for dtype. #include @@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor { n -= chunk_idx * chunk_size; for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { - float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8, - force_pow_2_scales, epsilon); + float scale_val = transformer_engine::compute_scale_from_amax( + amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits::max()); scale[i_start] = scale_val; transformer_engine::reciprocal(scale_inv + i_start, scale_val); } From 9dafe5edd999a0bb7b5b8f51e69521ab6b14b95c Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 10:49:25 -0700 Subject: [PATCH 25/38] Remove qopt calls from test. Signed-off-by: Keith Wyss --- tests/cpp/test_common.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 006b4bfd09..071c2186e0 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -329,8 +329,8 @@ Tensor::Tensor(const std::string& name, } } if (q_opts != nullptr) { - tensor_.set_qopt_force_pow_2_scales(q_opts->force_pow_2_scales); - tensor_.set_qopt_amax_epsilon(q_opts->amax_epsilon); + NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation."); + NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation."); } } } From 29d22ca2e02de67c100882bb2e311cd31ce703c8 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 11:41:33 -0700 Subject: [PATCH 26/38] Update pytest list. Signed-off-by: Keith Wyss --- qa/L0_pytorch_unittest/test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 8d38fa59df..21eaededc4 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -30,6 +30,8 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" From 279f7917e23f75ebfa4f29ddbc471e392e9d31f4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 12:17:37 -0700 Subject: [PATCH 27/38] Add copyright to reference scale calc. Signed-off-by: Keith Wyss --- tests/pytorch/references/quantize_scale_calc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pytorch/references/quantize_scale_calc.py b/tests/pytorch/references/quantize_scale_calc.py index bd1cb43356..f36ddca3b2 100644 --- a/tests/pytorch/references/quantize_scale_calc.py +++ b/tests/pytorch/references/quantize_scale_calc.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + from typing import Tuple import torch From fff1c6b56c58fae7a579bd5fffb85fadf183a190 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 13:31:45 -0700 Subject: [PATCH 28/38] Use ptx.cuh functions instead of cde. Signed-off-by: Keith Wyss --- .../quantize_transpose_square_blockwise.cu | 19 +++---- transformer_engine/common/util/ptx.cuh | 55 ++++++++++--------- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 8ef6267018..bda477b84c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -13,6 +13,7 @@ #include #include "common/common.h" +#include "common/util/ptx.cuh" #include "common/utils.cuh" #include "compute_scale.cuh" @@ -24,11 +25,6 @@ namespace transformer_engine { namespace { -#ifdef TMA_HW_SUPPORTED -using barrier = cuda::barrier; -namespace cde = cuda::device::experimental; -#endif - // const values configuration constexpr size_t kThreadsPerWarp = 32; @@ -214,21 +210,22 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } // Wait for shared memory writes to be visible to TMA engine. - cde::fence_proxy_async_shared_cta(); + ptx::fence_proxy_async_shared_cta(); __syncthreads(); // After syncthreads, writes by all threads are visible to TMA engine. // Step 5: store transpose output // Initiate TMA transfer to copy shared memory to global memory if (threadIdx.x == 0) { - cde::cp_async_bulk_tensor_2d_shared_to_global( - &tensor_map_output_t, tile_id_y * BLOCK_TILE_DIM, tile_id_x * BLOCK_TILE_DIM, - block_tile_trans_shared_otype_ptr); + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), tile_id_y * BLOCK_TILE_DIM, + tile_id_x * BLOCK_TILE_DIM, + reinterpret_cast(block_tile_trans_shared_otype_ptr)); // Wait for TMA transfer to have finished reading shared memory. // Create a "bulk async-group" out of the previous bulk copy operation. - cde::cp_async_bulk_commit_group(); + ptx::cp_async_bulk_commit_group(); // Wait for the group to have completed reading from shared memory. - cde::cp_async_bulk_wait_group_read<0>(); + ptx::cp_async_bulk_wait_group_read<0>(); } #else // Step 4 Alternative (when TMA is not available, skip writing to shared memory) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index a22b930ecd..55bc247f70 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( : "memory"); } +__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { + uint32_t waitComplete; + asm volatile( + "{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(mbar_ptr), "r"(parity) + : "memory"); + return static_cast(waitComplete); +} + +__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { + } +} + +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, @@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( : "memory"); } -__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { - uint32_t waitComplete; - asm volatile( - "{\n\t .reg .pred P_OUT; \n\t" - "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P_OUT; \n" - "}" - : "=r"(waitComplete) - : "r"(mbar_ptr), "r"(parity) - : "memory"); - return static_cast(waitComplete); -} - -__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { - uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); - while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group -__device__ __forceinline__ void cp_async_bulk_commit_group() { - asm volatile("cp.async.bulk.commit_group;"); -} - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group __device__ __forceinline__ void cp_async_bulk_wait_group() { asm volatile("cp.async.bulk.wait_group 0;"); @@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { asm volatile("cp.async.bulk.wait_group.read 4;"); } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + // Proxy fence (bi-directional): __device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } + __device__ __forceinline__ void fence_proxy_async_shared_cta() { asm volatile("fence.proxy.async.shared::cta;"); } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx From 9284a9ea5ca4a2413171054f60e5c3b527afbba9 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 13:59:42 -0700 Subject: [PATCH 29/38] Update shape logic with allocation and reuse shape. Signed-off-by: Keith Wyss --- transformer_engine/common/common.h | 11 ++++ .../transformer_engine/transformer_engine.h | 17 ------ .../common/transformer_engine.cpp | 58 ++----------------- .../csrc/extensions/type_converters.cpp | 26 +-------- 4 files changed, 20 insertions(+), 92 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 2d4629f0b1..5e8d21a96e 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -99,6 +99,12 @@ struct Tensor { SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; + private: + // Used as an allocation for nvte_tensor_shape + // if the shape has to be inferred from columnwise data. + mutable std::vector rowwise_shape_cache; + + public: NVTEScalingMode scaling_mode; Tensor() @@ -188,6 +194,11 @@ struct Tensor { } } + const std::vector &rowwise_shape_ref() const { + rowwise_shape_cache = shape(); + return rowwise_shape_cache; + } + /*! Matrix height after tensor is flattened to 2D * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 18244b8ece..c539265e62 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -110,19 +110,6 @@ typedef void *NVTETensor; */ NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); -/*! \brief Create a new TE tensor. - * - * Create a new TE tensor. Before use its parameters need to be set. - * TE tensors are just wrappers on top of raw data and do not - * own memory. - * - * \param[in] scaling_mode Scaling mode of the tensor. - * \param[in] initial_shape Shape to initialize tensor with. - * - * \return A new TE tensor. - */ -NVTETensor nvte_create_tensor_with_shape(NVTEScalingMode scaling_mode, NVTEShape initial_shape); - /*! \brief Destroy a TE tensor. * * Since the TE tensor does not own memory, the underlying @@ -436,10 +423,6 @@ class TensorWrapper { explicit TensorWrapper(const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) : tensor_(nvte_create_tensor(scaling_mode)) {} - TensorWrapper(const NVTEScalingMode scaling_mode, const std::vector &rowwise_shape) - : tensor_(nvte_create_tensor_with_shape( - scaling_mode, NVTEShape{rowwise_shape.data(), rowwise_shape.size()})) {} - /*! \brief TensorWrapper destructor. */ ~TensorWrapper() { nvte_destroy_tensor(tensor_); } diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index c8ea95c7b6..b3d4ca87dc 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -199,16 +199,6 @@ NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { return ret; } -NVTETensor nvte_create_tensor_with_shape(NVTEScalingMode scaling_mode, NVTEShape initial_shape) { - transformer_engine::Tensor *ret = new transformer_engine::Tensor; - ret->scaling_mode = scaling_mode; - ret->data.shape.reserve(initial_shape.ndim); - for (size_t i = 0; i < initial_shape.ndim; ++i) { - ret->data.shape.push_back(initial_shape.data[i]); - } - return ret; -} - void nvte_destroy_tensor(NVTETensor tensor) { if (tensor == nullptr) return; auto *t = reinterpret_cast(tensor); @@ -230,26 +220,12 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { // Determine tensor shape depending on tensor format const auto &t = *reinterpret_cast(tensor); switch (t.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!t.has_data() && t.has_columnwise_data()) { - // We can infer tensor shape if FP8 tensor only has FP8 data - // transpose. However, NVTEShape only contains a pointer and - // cannot store temporary data. We hack around this by caching - // the tensor shape within the empty FP8 data. - auto &shape_cache = const_cast &>(t.data.shape); - shape_cache.clear(); - if (!t.columnwise_data.shape.empty()) { - for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) { - shape_cache.push_back(t.columnwise_data.shape[i]); - } - shape_cache.push_back(t.columnwise_data.shape.front()); - } - ret.data = shape_cache.data(); - ret.ndim = shape_cache.size(); - } else { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - } + case NVTE_DELAYED_TENSOR_SCALING: + case NVTE_BLOCK_SCALING_1D: + case NVTE_BLOCK_SCALING_2D: { + const std::vector &rowwise_shape = t.rowwise_shape_ref(); + ret.data = rowwise_shape.data(); + ret.ndim = rowwise_shape.size(); break; } case NVTE_MXFP8_1D_SCALING: { @@ -262,28 +238,6 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { } break; } - case NVTE_BLOCK_SCALING_1D: - case NVTE_BLOCK_SCALING_2D: { - if (!t.has_data() && t.has_columnwise_data()) { - std::vector shape; - ret.ndim = t.columnwise_data.shape.size(); - shape.reserve(ret.ndim); - for (int i = 0; i + 1 < static_cast(ret.ndim); ++i) { - shape.push_back(t.columnwise_data.shape[i + 1]); - } - if (ret.ndim > 0) { - shape.push_back(t.columnwise_data.shape[0]); - } - NVTE_CHECK(t.data.shape == shape, - "Must return shape allocated on tensor. " - "data shape expected to match derivation from columnwise."); - ret.data = t.data.shape.data(); - } else { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - } - break; - } default: NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", transformer_engine::to_string(t.scaling_mode), "\""); diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index 6a32d4cec0..cb2121a457 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -91,34 +91,14 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); - std::vector initial_rowwise_shape; - if (rowwise_usage) { - initial_rowwise_shape = getTensorShape(tensor.attr("_rowwise_data").cast()); - } else if (columnwise_usage) { - std::vector columnwise_shape = - getTensorShape(tensor.attr("_columnwise_data").cast()); - - // Even though we don't have rowwise data, we want to store the - // rowwise shape so that nvte_tensor_shape can return an allocated - // vector. - initial_rowwise_shape.reserve(columnwise_shape.size()); - for (size_t i = 0; i + 1 < columnwise_shape.size(); ++i) { - initial_rowwise_shape.push_back(columnwise_shape[i + 1]); - } - if (columnwise_shape.size() > 0) { - initial_rowwise_shape.push_back(columnwise_shape[0]); - } - } - - auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, - initial_rowwise_shape); + auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); if (rowwise_usage) { const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, initial_rowwise_shape); - + const auto &rowwise_shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); } From b52a44dccffcc991570b795f01e420d5bf7b3663 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 15:50:10 -0700 Subject: [PATCH 30/38] Usage defaults MR feedback. Signed-off-by: Keith Wyss --- .../pytorch/tensor/float8_blockwise_tensor.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 2de587fbff..34fcdf6f03 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -301,11 +301,17 @@ def detach(self) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return Float8BlockwiseQTensor.make_like(self) - def update_usage(self, rowwise_usage=True, columnwise_usage=True): + def update_usage( + self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None + ): """ update_usage can be used to clear out one of two possible copies of the data. """ + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None assert ( columnwise_usage or rowwise_usage ), "Must retain some data either columnwise or rowwise" From 18d80bdaf83bb518b6e14629e9ad0dedd968beab Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 09:52:23 -0700 Subject: [PATCH 31/38] Copyright and header guard. Signed-off-by: Keith Wyss --- .../common/transpose/compute_scale.cuh | 8 +++---- .../pytorch/tensor/float8_blockwise_tensor.py | 22 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/transpose/compute_scale.cuh b/transformer_engine/common/transpose/compute_scale.cuh index 8013d452a6..91e02dc537 100644 --- a/transformer_engine/common/transpose/compute_scale.cuh +++ b/transformer_engine/common/transpose/compute_scale.cuh @@ -1,11 +1,11 @@ /************************************************************************* - * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ -#ifndef TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ -#define TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ +#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_COMPUTE_SCALE_CUH_ +#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_COMPUTE_SCALE_CUH_ #include #include @@ -96,4 +96,4 @@ __device__ __forceinline__ float ComputeScale(const float amax, const float eps) } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ +#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_COMPUTE_SCALE_CUH_ diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 34fcdf6f03..6065250daa 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -468,7 +468,6 @@ def _set_data(self, tensor: torch.Tensor) -> None: casts to FP8. """ - # Tensor device new_device = tensor.device if tensor.is_cuda else self.device @@ -479,24 +478,25 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv - if dst.requires_grad != src.requires_grad: - dst.requires_grad_(requires_grad=src.requires_grad) + dst.dtype = src.dtype + + # Check that tensor dimensions match + if ( + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.layout != tensor.layout + ): + raise ValueError("Invalid tensor for updating Float8BlockwiseQTensor data") # Just copy FP8 data if other tensor is Float8BlockwiseQTensor - compatible_layout = ( - isinstance(tensor, Float8BlockwiseQTensor) - and self.size() == tensor.size() - and self.stride() == tensor.stride() - ) if ( - compatible_layout + isinstance(tensor, Float8BlockwiseQTensor) and self.storage_offset() == tensor.storage_offset() - and self.dtype == tensor.dtype - and self.layout == tensor.layout and devices_match(self.device, new_device) ): _set_from_tensor(self, tensor) return + if isinstance(tensor, Float8BlockwiseQTensor): assert tensor._quantizer is not None, "Can't quantize without a quantizer" quantizer = tensor._quantizer From 18f19bbe0f6f960cbdd09695f629d598dd29d70f Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 13:33:49 -0700 Subject: [PATCH 32/38] Updating torch dispatch code. Signed-off-by: Keith Wyss --- .../pytorch/tensor/float8_blockwise_tensor.py | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 6065250daa..1cb89a7d81 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -360,25 +360,35 @@ def clone(self) -> Float8BlockwiseQTensor: def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring - if not self.requires_grad: - # Autograd removes the quantized return type - # because of __torch_function__ in base class - # and torch._C._disabled_torch_function_impl - return _ViewFunc.forward(None, self, shape) - return super().view(self, *shape) + return _ViewFunc.apply(self, shape) def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring - if not self.requires_grad: - return _ReshapeFunc.forward(None, self, shape) - return super().reshape(self, *shape) + return _ReshapeFunc.apply(self, shape) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # View op if func == aten.view.default: - return _ViewFunc.apply(args[0], *args[1:]) + tensor = args[0] + data = tensor._rowwise_data + if data is None: + # Columnwise data only. + super().__torch_dispatch__(func, types, args, kwargs) + orig_size = data.size() + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + if orig_size != out_data.size(): + raise NotImplementedException( + "Changing shape with view not implemented " + " (scales and columnwise data untouched)." + ) + return Float8BlockwiseQTensor.make_like(tensor) # Default case return super().__torch_dispatch__(func, types, args, kwargs) @@ -458,7 +468,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: def _get_data(self) -> Float8BlockwiseQTensor: """Get tensor data property""" - return self.dequantize() + return self @torch.no_grad() def _set_data(self, tensor: torch.Tensor) -> None: From 572c04bd45b164c6924fdc4d96a4059aee9a8e5a Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 13:53:14 -0700 Subject: [PATCH 33/38] Fix exception type. Signed-off-by: Keith Wyss --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 1cb89a7d81..bdb8711675 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -384,7 +384,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): kwargs, ) if orig_size != out_data.size(): - raise NotImplementedException( + raise NotImplementedError( "Changing shape with view not implemented " " (scales and columnwise data untouched)." ) From bac93485d8214893eb102017c2f958221ce57811 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 14:47:45 -0700 Subject: [PATCH 34/38] Use TypeInfo Signed-off-by: Keith Wyss --- transformer_engine/common/common.h | 31 ++++++ .../common/recipe/recipe_common.cuh | 22 +++++ .../common/transpose/compute_scale.cuh | 99 ------------------- .../quantize_transpose_square_blockwise.cu | 16 +-- .../quantize_transpose_vector_blockwise.cu | 14 +-- 5 files changed, 61 insertions(+), 121 deletions(-) delete mode 100644 transformer_engine/common/transpose/compute_scale.cuh diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 5e8d21a96e..b1fe436379 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -280,6 +280,36 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) #endif #undef TRANSFORMER_ENGINE_TYPE_NAME +template +struct TypeExtrema; + +template <> +struct TypeExtrema { + static constexpr float max = 448.0f; +}; + +template <> +struct TypeExtrema { + static constexpr float max = 57344.0f; +}; + +template <> +struct TypeExtrema { + // Hex float format of 1.(7 bits of 1) * 2 ^ 127 + static constexpr float max = 0x1.FEp127; +}; + +template <> +struct TypeExtrema { + // Hex float format of 1.(10 bits of 1) * 2 ^ 15 + static constexpr float max = 0x1.FFCp15; +}; + +template +struct TypeExtrema { + static constexpr float max = std::numeric_limits::max(); +}; + } // namespace detail template @@ -310,6 +340,7 @@ struct TypeInfo { constexpr static DType dtype = getType(); constexpr static size_t size = sizeof(T); + constexpr static float max_finite_value = detail::TypeExtrema::max; constexpr static const char *name = detail::type_name(); }; diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh index f7c6e12fbf..11f9bc1299 100644 --- a/transformer_engine/common/recipe/recipe_common.cuh +++ b/transformer_engine/common/recipe/recipe_common.cuh @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ #define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ +#include "common/common.h" + namespace transformer_engine { __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8, @@ -46,6 +48,26 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f return scale; } +// Calculate the quantization scale for an individual data element +// given the amax(abs(tile)) value for a given quantization tile. +// +// +// Arguments: +// IType: data type of the tensor being quantized (float or bf16) +// OType: quantized data type (e4m3 or e5m2) +// amax: The evaluation of amax(abs(tile)) for the quantization tile. +// eps: An epsilon used as a floor for amax. +// pow_2_scaling: Whether to force the scale to be a power of 2. +template +__device__ __forceinline__ float compute_scale_from_types(const float amax, const float eps, + const float pow_2_scaling) { + constexpr float fp8_max = TypeInfo::max_finite_value; + // NOTE: We're relying on compute_scale_from_amax to have behavior where it + // clips the mantissa of the max_finite_value if power of 2 scaling applies. + constexpr float value_for_inf = TypeInfo::max_finite_value; + return compute_scale_from_amax(amax, fp8_max, pow_2_scaling, eps, value_for_inf); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ diff --git a/transformer_engine/common/transpose/compute_scale.cuh b/transformer_engine/common/transpose/compute_scale.cuh deleted file mode 100644 index 91e02dc537..0000000000 --- a/transformer_engine/common/transpose/compute_scale.cuh +++ /dev/null @@ -1,99 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_COMPUTE_SCALE_CUH_ -#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_COMPUTE_SCALE_CUH_ - -#include -#include -#include - -#include -#include - -#include "../recipe/recipe_common.cuh" - -namespace transformer_engine { - -// Type trait for extreme values of fp8 types. -// Used in the calculation of scale factors -// as a constexpr lookup from e4m3 or e5m2 to -// the max finite value. -template -struct F8LimitsTrait; - -template <> -struct F8LimitsTrait<__nv_fp8_e4m3> { - static constexpr float max = 448.0f; -}; - -template <> -struct F8LimitsTrait<__nv_fp8_e5m2> { - static constexpr float max = 57344.0f; -}; - -// Type trait to resolve the max finite value -// represented by a input type to quantization. -// Or to represent max representable power of 2 -// finite value. -template -struct HighPrecisionFloatScaleLimitsTrait; - -template <> -struct HighPrecisionFloatScaleLimitsTrait { - static constexpr float max = std::numeric_limits::max(); -}; - -template <> -struct HighPrecisionFloatScaleLimitsTrait { - // Hex float format of 1.0 * 2 ^ 127 - static constexpr float max = 0x1.0p127; -}; - -template <> -struct HighPrecisionFloatScaleLimitsTrait { - // Hex float format of 1.(7 bits of 1) * 2 ^ 127 - static constexpr float max = 0x1.FEp127; -}; - -template <> -struct HighPrecisionFloatScaleLimitsTrait { - // Hex float format of 1.0 * 2 ^ 127 - static constexpr float max = 0x1.0p127; -}; - -template <> -struct HighPrecisionFloatScaleLimitsTrait { - // Hex float format of 1.(10 bits of 1) * 2 ^ 15 - static constexpr float max = 0x1.FFCp15; -}; - -template <> -struct HighPrecisionFloatScaleLimitsTrait { - // Hex float format of 1.0 * 2 ^ 15 - static constexpr float max = 0x1.0p15; -}; - -// Calculate the quantization scale for an individual data element -// given the amax(abs(tile)) value for a given quantization tile. -// -// -// Arguments: -// IType: data type of the tensor being quantized (float or bf16) -// OType: quantized data type (e4m3 or e5m2) -// pow_2_scaling: Whether to force the scale to be a power of 2. -// amax: The evaluation of amax(abs(tile)) for the quantization tile. -// eps: An epsilon used as a floor for amax. -template -__device__ __forceinline__ float ComputeScale(const float amax, const float eps) { - constexpr float fp8_max = F8LimitsTrait::max; - constexpr float value_for_inf = HighPrecisionFloatScaleLimitsTrait::max; - return compute_scale_from_amax(amax, fp8_max, Power2Scaling, eps, value_for_inf); -} - -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_COMPUTE_SCALE_CUH_ diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index bda477b84c..663c61a1cf 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -13,9 +13,9 @@ #include #include "common/common.h" +#include "common/recipe/recipe_common.cuh" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "compute_scale.cuh" #if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \ (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) @@ -149,11 +149,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) __syncthreads(); block_tile_amax = block_tile_amax_shared[0]; - if (pow_2_scaling) { - block_tile_scale = ComputeScale(block_tile_amax, epsilon); - } else { - block_tile_scale = ComputeScale(block_tile_amax, epsilon); - } + block_tile_scale = + compute_scale_from_types(block_tile_amax, epsilon, pow_2_scaling); if (threadIdx.x == 0) { static_assert(std::is_same::value); @@ -375,11 +372,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose __syncthreads(); block_tile_amax = block_tile_amax_shared[0]; - if (pow_2_scaling) { - block_tile_scale = ComputeScale(block_tile_amax, epsilon); - } else { - block_tile_scale = ComputeScale(block_tile_amax, epsilon); - } + block_tile_scale = + compute_scale_from_types(block_tile_amax, epsilon, pow_2_scaling); if (threadIdx.x == 0) { static_assert(std::is_same::value); diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 2304084632..732d97999c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -15,8 +15,8 @@ #include #include "common/common.h" +#include "common/recipe/recipe_common.cuh" #include "common/utils.cuh" -#include "compute_scale.cuh" namespace transformer_engine { namespace { @@ -254,11 +254,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) amax = __shfl_sync(mask, amax, src_lane); CType scale; // Step 2.4: Compute scale - if (pow_2_scaling) { - scale = ComputeScale(amax, epsilon); - } else { - scale = ComputeScale(amax, epsilon); - } + scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); // Step 2.5: Write scale_inv bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -347,11 +343,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) amax = __shfl_sync(mask, amax, src_lane); // Step 3.4: Compute scale CType scale; - if (pow_2_scaling) { - scale = ComputeScale(amax, epsilon); - } else { - scale = ComputeScale(amax, epsilon); - } + scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); // Step 3.5: Write scale_inv_t bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { From 93d2bf5098ee3c19f9e8f88ea017828058184d24 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 18:02:40 -0700 Subject: [PATCH 35/38] MR feedback. Signed-off-by: Keith Wyss --- .../blockwise_quantizer_reference.py | 2 +- tests/pytorch/references/ref_per_tensor_cs.py | 2 +- .../test_float8_blockwise_scaling_exact.py | 2 +- .../common/include/transformer_engine/cast.h | 35 +++++++++---------- .../common/transformer_engine.cpp | 28 +++------------ .../pytorch/csrc/extensions/quantizer.cpp | 15 +++++--- .../pytorch/tensor/float8_blockwise_tensor.py | 1 + 7 files changed, 36 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index 1f85a1bb6b..b98966f514 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -6,7 +6,7 @@ import math import torch from typing import Optional, Protocol, Tuple -from tests.pytorch.references.quantize_scale_calc import scale_from_amax_tensor +from references.quantize_scale_calc import scale_from_amax_tensor @dataclasses.dataclass() diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index 7c0a161b1c..085071c69a 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -6,7 +6,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType_To_Torch -from tests.pytorch.references.quantize_scale_calc import scale_from_amax_tensor +from references.quantize_scale_calc import scale_from_amax_tensor # compute amax and scale diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index a0e11a7af2..e638fe8c5b 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -14,7 +14,7 @@ Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from tests.pytorch.references.blockwise_quantizer_reference import ( +from references.blockwise_quantizer_reference import ( BlockwiseQuantizerReference, QuantizeResult, ) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index b46ca51256..7fa7957fa4 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -17,16 +17,20 @@ extern "C" { #endif -/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer) +/* Quantize the tensor + * + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. + * * Supported formats are: * - * 1) MXFP8 scaling (for 10.0 or newer) + * 1) MXFP8 scaling (for compute capability 10.0 or newer) * * The MXFP8 implementation is per the microscaling format MXFP8 defined by the OCP specification: * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf * * - * Supported modes of MXFP8 scaling (live scaling): + * Supported modes of MXFP8 scaling (live scaling) for scaling mode NVTE_MXFP8_1D_SCALING * a) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: * - the scaled output tensor * - the corresponding scaling factors @@ -46,12 +50,12 @@ extern "C" { * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter * of the output tensor should be set to 0. * - * Also supported are - * - * 2) per-tensor scaling modes that quantize the entire tensor + * 2) NVTE_DELAYED_TENSOR_SCALING that quantize the entire tensor * using a single scaling factor. The absolute maximum value of the tensor should * be precalculated either online (current scaling) or based on a tensor history * (delayed scaling). The calls to nvte_quantize scale based on that data value. + * Note the NVTE_DELAYED_TENSOR_SCALING NVTEScalingMode is reused for online + * per tensor scaling. * * * 3) FP8 block scaling formats NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D @@ -59,30 +63,25 @@ extern "C" { * of size 1x128 (with columnwise mode of 128x1) and 128x128 respectively. * * The supported modes are: - * a) Rowwise scaling (along the dim=0) yields output data + * a) Rowwise scaling yields output data: * - the scaled output tensor in fp8 coefficients with identical shape to the * input tensor. * - Scale factors which are computed for either 1D 1x128 or 2D 128x128 blocks. - * b) Columnwise scaling (along the dim=1) yields output data + * b) Columnwise scaling yields output data: * - the scaled output tensor in fp8 coefficients with a shape equivalent to * the transpose of the input tensor. * - Scale factors which are calculated for either 1D 128x1 or 2D 128x128 blocks * of the input tensor. - * c) Both: In which all four tensors of the above are calculated. + * c) Both: In which both tensors and both scales are calculated. * * This quantization mode includes both the calculation of the scaling factors * per-tile and quantization of the row and/or columnwise tiles. No precalculated - * absolute max is required. The scaling factors are also rounded to powers of 2, - * such that even if they are stored in fp32 on compute capability 9.0, they are - * numerically compatible with e8m0 scales such that the quantization is portable - * to hardware supporting MXFP8 without numerical disruption. + * absolute max is required. The scaling factors are also rounded to powers of 2. */ /*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8. - * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the MXFP8 block quantization of the specified shape of the block will be used. - * If the scaling mode of the output tensor is set to NVTE_BLOCK_SCALING_1D or NVTE_BLOCK_SCALING_2D, - * blockwise float8 scaling will be used. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. * * \param[in] input Input tensor to be cast. * \param[in,out] output Output FP8/MXFP8/BlockwiseFP8 tensor. @@ -93,7 +92,7 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea /*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. * The type of quantized tensor in the output depends on the scaling mode of the output - * tensor. + * tensor. See file level comments. * * \param[in] input Input tensor to be cast. * \param[in,out] output Output quantized tensor. diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b3d4ca87dc..97df5892b6 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -215,34 +215,14 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { if (tensor == nullptr) { NVTE_ERROR("Invalid tensor"); } - NVTEShape ret; // Determine tensor shape depending on tensor format const auto &t = *reinterpret_cast(tensor); - switch (t.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: - case NVTE_BLOCK_SCALING_1D: - case NVTE_BLOCK_SCALING_2D: { - const std::vector &rowwise_shape = t.rowwise_shape_ref(); - ret.data = rowwise_shape.data(); - ret.ndim = rowwise_shape.size(); - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (!t.has_data() && t.has_columnwise_data()) { - ret.data = t.columnwise_data.shape.data(); - ret.ndim = t.columnwise_data.shape.size(); - } else { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - } - break; - } - default: - NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", - transformer_engine::to_string(t.scaling_mode), "\""); - } + const std::vector &rowwise_shape = t.rowwise_shape_ref(); + NVTEShape ret; + ret.data = rowwise_shape.data(); + ret.ndim = rowwise_shape.size(); return ret; } diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 364cd26ef9..19d8a75a64 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -292,7 +292,7 @@ std::pair Float8BlockQuantizer::create_tensor( numel *= s; } - TensorWrapper tensor((block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); + TensorWrapper tensor(this->get_scaling_mode()); at::TensorOptions opts; at::TensorOptions scale_opts; at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; @@ -318,7 +318,10 @@ std::pair Float8BlockQuantizer::create_tensor( sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; sinv1 = roundup(m_dim, 4); } else { - NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor rowwise."); + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor rowwise." + "Expected 1 or 2. Got ", + block_scaling_dim); } scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts); tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); @@ -329,7 +332,8 @@ std::pair Float8BlockQuantizer::create_tensor( if (columnwise_usage) { std::vector torch_columnwise_shape; std::vector columnwise_shape; - NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape."); + NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ", + columnwise_shape, " torch shape: ", torch_columnwise_shape); if (torch_shape.size() > 0) { torch_columnwise_shape.reserve(torch_shape.size()); columnwise_shape.reserve(shape.size()); @@ -349,7 +353,10 @@ std::pair Float8BlockQuantizer::create_tensor( sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; sinv1 = roundup(k_dim, 4); } else { - NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor columnwise."); + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor columnwise." + "Expected 1 or 2. Got ", + block_scaling_dim); } data_colwise = at::empty(torch_columnwise_shape, opts); scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts); diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index bdb8711675..138d1fd29e 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -251,6 +251,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): def __repr__(self, *, tensor_contents=None): return ( f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," + f" is_2D_scaled={self._is_2D_scaled}," f" data={self.dequantize(dtype=self.dtype)})" ) From 15f10071c0676f8a269ff3344fbff8d7aa2796f3 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 4 Apr 2025 02:57:28 +0000 Subject: [PATCH 36/38] Update CS scale update test to use updated ref impl Signed-off-by: Tim Moon --- tests/pytorch/references/ref_per_tensor_cs.py | 4 ---- tests/pytorch/test_multi_tensor.py | 11 ++++++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index 085071c69a..5e803f7ed5 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -66,7 +66,3 @@ def ref_per_tensor_cs_cast( qx_t = _multi_dim_transpose(qx) sx_t = sx return qx, sx, qx_t, sx_t - - -def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): - return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 4dc1ec087f..737b5ff2b0 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -9,7 +9,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.optimizers import MultiTensorApply -from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax +from references.quantize_scale_calc import scale_from_amax_tensor input_size_pairs = [ @@ -224,17 +224,18 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)]) @pytest.mark.parametrize("applier", appliers) @pytest.mark.parametrize("repeat", [1, 55]) -@pytest.mark.parametrize("max_fp8", [448.0, 57344.0]) +@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("pow_2_scales", [False, True]) @pytest.mark.parametrize("epsilon", [0.0, 100.0]) def test_multi_tensor_compute_scale_and_scale_inv( - input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon + input_size_pair, applier, repeat, fp8_dtype, pow_2_scales, epsilon ): sizea, sizeb = input_size_pair device = torch.device("cuda") overflow_buf = torch.zeros(1, dtype=torch.int32, device=device) a = torch.randn([sizea], dtype=torch.float32, device=device).abs() b = torch.randn([sizeb], dtype=torch.float32, device=device).abs() + max_fp8 = torch.finfo(fp8_dtype).max amax_list = [] for i in range(repeat): @@ -253,8 +254,8 @@ def test_multi_tensor_compute_scale_and_scale_inv( ) for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list): - scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax( - amax, max_fp8, epsilon, pow_2_scales + scale_ref, scale_inv_ref, _ = scale_from_amax_tensor( + torch.float32, amax, fp8_dtype, eps=epsilon, pow_2_scales=pow_2_scales ) torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0) torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0) From 1a34a86070fbf3c31799b99eb21bd65c7842c146 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 4 Apr 2025 04:36:57 +0000 Subject: [PATCH 37/38] Update JAX scaling mode enum Signed-off-by: Tim Moon --- transformer_engine/jax/quantize/scaling_modes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 7aecc34643..805c034334 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -183,8 +183,8 @@ class ScalingMode(Enum): NVTE_DELAYED_TENSOR_SCALING = 0 NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 2 - NVTE_NO_SCALING = 3 + NVTE_INVALID_SCALING = 4 + NVTE_NO_SCALING = 5 def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. From 51f7b29a9f8711202949e169bc6765e0e7a7f004 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 4 Apr 2025 04:50:42 +0000 Subject: [PATCH 38/38] Skip tests on Lovelace Signed-off-by: Tim Moon --- tests/pytorch/test_float8blockwisetensor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 316842b4f3..d030426b74 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -12,11 +12,11 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, Float8BlockwiseQTensor, ) +from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex # PyTorch tensor dtypes @@ -42,11 +42,12 @@ def _to_list(x: Union[Iterable, Any]) -> List: # Types that can be interpreted as tensor dims DimsType = Union[Iterable[int], int] -# Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# TODO replace with call to fp8.py when recipe added. +recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 +reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) class TestFloat8BlockwiseTensor: @staticmethod