40 changes: 40 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
@@ -489,3 +489,43 @@ def test_nvfp4_quantization_noncontiguous_inputs(
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)

torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(32, 128),
],
)
@pytest.mark.parametrize(
"with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
)
def test_nvfp4_3d_shape_quantization(
M: int,
N: int,
with_2d_quantization: bool,
):
te_dtype = tex.DType.kFloat4E2M1
device = "cuda"
# Input
x = torch.randn((M, 4, N), dtype=torch.bfloat16, device=device)

# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=with_2d_quantization,
)
q_x = nvfp4_quantizer(x)
x *= 2
nvfp4_quantizer.update_quantized(x, q_x)
assert q_x._rowwise_data is not None
assert len(q_x._rowwise_data.shape) == 3
assert q_x._columnwise_data is not None
assert len(q_x._columnwise_data.shape) == 2
33 changes: 31 additions & 2 deletions transformer_engine/pytorch/csrc/quantizer.cpp
@@ -6,6 +6,9 @@

#include <pybind.h>

#include <functional>
#include <numeric>

#include "common.h"
#include "pybind.h"
#include "torch/torch.h"
@@ -1260,6 +1263,29 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_unquantized_tensor_w
return {std::move(out_cpp), std::move(out_py)};
}

/**
* @brief Compress an N-D shape into a 2-D shape by flattening all but the last dimension.
*
 * This utility is intended for comparing N-dimensional tensor shapes on a common 2-D basis:
 * it flattens every dimension except the final one into a single leading dimension and
 * keeps the last dimension unchanged.
*
* Example: [d0, d1, d2, ..., d{n-2}, d{n-1}] -> [d0*d1*...*d{n-2}, d{n-1}]
*
* If the input has 2 or fewer dimensions, it is returned unchanged.
*/
std::vector<size_t> compressShapeTo2D(const std::vector<size_t>& data) {
// A shape with 2 or fewer dimensions is already at most 2-D; return as-is
if (data.size() <= 2) {
return data;
}
// Multiply all elements except the last
size_t product = std::accumulate(data.begin(), data.end() - 1, static_cast<size_t>(1),
std::multiplies<size_t>());
// Return new vector of size 2: {product, last}
return std::vector<size_t>{product, data.back()};
}
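
// For illustration only (not part of the diff): a minimal, self-contained C++ sketch of how
// compressShapeTo2D behaves, re-declaring the helper locally and using hypothetical shapes
// such as the (32, 4, 128) input from the new test.
//
// #include <cassert>
// #include <cstddef>
// #include <functional>
// #include <numeric>
// #include <vector>
//
// std::vector<size_t> compressShapeTo2D(const std::vector<size_t>& data) {
//   if (data.size() <= 2) {
//     return data;
//   }
//   size_t product = std::accumulate(data.begin(), data.end() - 1, static_cast<size_t>(1),
//                                    std::multiplies<size_t>());
//   return std::vector<size_t>{product, data.back()};
// }
//
// int main() {
//   // A 3-D shape like the test input (32, 4, 128) flattens to (32*4, 128) = (128, 128);
//   // 2-D (and lower) shapes pass through unchanged.
//   assert((compressShapeTo2D({32, 4, 128}) == std::vector<size_t>{128, 128}));
//   assert((compressShapeTo2D({64, 256}) == std::vector<size_t>{64, 256}));
//   return 0;
// }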

std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor.");
@@ -1289,8 +1315,11 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
if (rowwise_data) {
auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
") and column-wise data (shape=", shape, ") do not match");
auto expected_shape_2d = compressShapeTo2D(expected_shape);
auto shape_2d = compressShapeTo2D(shape);
NVTE_CHECK(shape_2d == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d,
") and column-wise data (2D shape=", shape_2d, ") do not match");
shape = expected_shape;
}
} else { // Already checked columnwise_data_tensor == true
shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);