From 93b48c9ad3688c70dbc82c46d16b6bbb39f5775f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 7 Oct 2025 17:56:16 -0700 Subject: [PATCH 01/14] PR6: exposing python examples --- python/python_direct/ops.cpp | 34 ++++ python/python_direct/python_translate.cpp | 9 + tests/python/direct/test_narrow_precision.py | 180 ++++++++++++++++++ tests/python/direct_utils/narrow_precision.py | 8 +- 4 files changed, 227 insertions(+), 4 deletions(-) diff --git a/python/python_direct/ops.cpp b/python/python_direct/ops.cpp index 4b51037077e..abf2ac31449 100644 --- a/python/python_direct/ops.cpp +++ b/python/python_direct/ops.cpp @@ -3188,6 +3188,40 @@ Notes - The weight tensor must be exactly 2D. - All optional parameters must be scalar values when provided. - This operation is equivalent to PyTorch's torch.nn.functional.embedding. +)", + py::return_value_policy::reference); + ops.def( + "preprocess_grouped_matmul_input_sf", + [](TensorView* input, + TensorView* input_offsets, + TensorView* output_offsets) -> decltype(auto) { + return preprocessGroupedMatmulInputSf( + input, + input_offsets, + output_offsets, + BlockScalingFactorLayout::Block128x4); + }, + py::arg("input"), + py::arg("input_offsets"), + py::arg("output_offsets"), + R"( +Layout operation to apply per group swizzle & padding for grouped matmul block scaling factor for activation. + +Parameters +---------- +input : TensorView + A 2D tensor containing blockwise scaling factor +input_offsets: TensorView + A 1D tensor with length as (1 + number of groups). + Its value notes the offsets of the starting token in each group, where the last entry contains the total number of token +output_offsets: TensorView + A 1D tensor with length as (1 + number of groups). + Its value notes the offsets of the starting token in each group at the output tensor view. + +Returns +------- +TensorView + A tensor with proper swizzle & padding in memory. Note that the actual padding in buffer is not represented by the size/stride of the output tensor. )", py::return_value_policy::reference); } diff --git a/python/python_direct/python_translate.cpp b/python/python_direct/python_translate.cpp index 50d95de5123..77e113ce397 100644 --- a/python/python_direct/python_translate.cpp +++ b/python/python_direct/python_translate.cpp @@ -1615,6 +1615,15 @@ class PythonTranslator : public OptInConstDispatch { {eop->out()}); } + void handle(const PreprocessGroupedMatmulInputSf* layout_op) final { + NVF_ERROR(layout_op != nullptr); + visited_vals_.insert(layout_op->output(0)); + printer_.generateOperation( + "fd.ops.preprocess_grouped_matmul_input_sf", + {layout_op->in()->as(), layout_op->inputOffsets(), layout_op->outputOffsets()}, + {layout_op->out()}); + } + private: //! Convert CPP values to python syntax. PythonPrinter printer_; diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 570e65d44ca..2826642e2d9 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -12,6 +12,7 @@ from nvfuser_direct.pytorch_utils import torch_dtype_to_nvfuser_dtype from python.direct_utils import ( FLOAT4_E2M1_MAX, + FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX, pytorch_nvfp4_quantize, is_pre_blackwell, @@ -278,3 +279,182 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: ) assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) + + +@pytest.mark.skipif( + is_pre_blackwell(), reason="Only supported on blackwell and newer devices." +) +#@pytest.mark.parametrize("config", [[1024, 128, 16*9]]) +@pytest.mark.parametrize("config", [[1024, 128, 256]]) +@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) +def test_layout_op_and_cutlass_nvfp4_grouped_mm( + nvfuser_direct_test, + config, + tokens_per_expert_neg_one, + out_dtype, +): + INPUT_DTYPE = torch.uint8 + BLOCK_SIZE = 16 + + # k dimension is multiple of 128 to avoid padding + m, n, k = config + tokens_per_expert = tokens_per_expert_neg_one + tokens_per_expert.append(m - sum(tokens_per_expert)) + g = len(tokens_per_expert) + + mat1 = torch.testing.make_tensor((m, k), dtype=torch.float32, device="cuda:0") + # format is g, n, k instead of g, k, n + mat2 = torch.testing.make_tensor((g, n, k), dtype=torch.float32, device="cuda:0") + + offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0") + blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0") + problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0") + + # prepare quantization for mat2 + mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0") + scale2 = torch.empty( + (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0" + ) + + acc_tokens = 0 + rounded_acc_tokens = 0 + mat2_scaled = torch.empty( + (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0" + ) + + for i in range(g): + mat2_gs[i] = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max() + offsets[i] = acc_tokens + blockscale_offsets[i] = rounded_acc_tokens + acc_tokens += tokens_per_expert[i] + # Note: we technically don't need to round up, since k is perfectly sized. + rounded_acc_tokens += round_up(tokens_per_expert[i], 128) + + problem_sizes[i][0] = tokens_per_expert[i] + problem_sizes[i][1] = n + problem_sizes[i][2] = k + + scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], mat2_gs[i]) + mat2_scaled[i] = scaled_mat2_i + scale2[i] = linear_to_swizzled_128_4(bs_mat2_i) + + + ab_strides = torch.full((g,), k, dtype=torch.int64, device="cuda:0") + c_strides = torch.full((g,), n, dtype=torch.int64, device="cuda:0") + + def nvfuser_fusion_id0(fd: FusionDefinition) -> None: + mat1 = fd.define_tensor( + shape=[-1, -1], + contiguity=True, + dtype=DataType.Float, + is_cpu=False, + ) + mat2 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=True, + dtype=DataType.Float4_e2m1fn, + is_cpu=False, + stride_order=[2, 0, 1], + ) + scale2 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=True, + dtype=DataType.Float8_e4m3fn, + is_cpu=False, + ) + alpha = fd.define_tensor( + shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False + ) + problem_sizes = fd.define_tensor( + shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False + ) + offsets = fd.define_tensor( + shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False + ) + blockscale_offsets = fd.define_tensor( + shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False + ) + # TODO: fix dynamic shape in issue https://github.com/NVIDIA/Fuser/issues/5199 + m_size = m + k_size = k + k_tile_size = k_size // 16; + # m_size = fd.ops.size(mat1, 0) + # k_size = fd.ops.size(mat1, 1) + # k_tile_size = fd.ops.div(k_size, 16) + # using primitive operations to handle quantization + reshaped_mat1 = fd.ops.reshape(mat1, [m_size, k_tile_size, 16]) + + scale1 = fd.ops.abs(reshaped_mat1) + scale1 = fd.ops.max(scale1, 2) + scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX) + scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX) + + broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True]) + reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1) + reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX) + + scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size]) + # should I clamp here before cast?! + fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn) + fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn) + # NOTE: I need to add an entry for translation rule to print out this + layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, offsets, blockscale_offsets) + # NOTE: it's not working with the grouped_mm. Looks like segmentation is a bit different. But I think it's also exposing some dependency issue above. + out = fd.ops.cutlass_nvfp4_grouped_mm( + fp4_mat1, + mat2, + layout_fp8_scale1, + scale2, + alpha, + problem_sizes, + offsets, + blockscale_offsets, + DataType.BFloat16, + ) + fd.add_output(out) + + inputs = [ + mat1, + mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2), + scale2, + mat2_gs, + problem_sizes, + offsets, + blockscale_offsets, + ] + + o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) + + # quantization for activation is needed for reference. + # note: following sglang implementation, not computing global scaling factor for mat1 + # similarly, we don't need to apply mat1_gs to alpha + mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0") + mat1_fp4, scale1 = activation_scale_to_nvfp4( + mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE + ) + o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0") + for i in range(g): + l = offsets[i] + l_sf = blockscale_offsets[i] + if i == g - 1: + r = m + else: + r = offsets[i + 1] + r_sf = round_up(tokens_per_expert[i], 128) + l_sf + # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel. + # This triggers a cublas invalid value error. + o_decomposed_ref[l:r] = ( + torch._scaled_mm( + mat1_fp4[l:r], + mat2_scaled[i].transpose(-1, -2), + scale1[l_sf:r_sf], + scale2[i], + None, + None, + torch.bfloat16, + ) + * mat2_gs[i] + ) + + assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) diff --git a/tests/python/direct_utils/narrow_precision.py b/tests/python/direct_utils/narrow_precision.py index 6b2537d20b7..b1913e362ed 100644 --- a/tests/python/direct_utils/narrow_precision.py +++ b/tests/python/direct_utils/narrow_precision.py @@ -137,11 +137,11 @@ def pytorch_nvfp4_quantize(a, a_global_scale): block_scale_fp32 = (max_abs / FLOAT4_E2M1_MAX).float() scaled_block_scale_fp32 = block_scale_fp32 * a_global_scale - scaled_block_scale_fp8 = torch.clamp( + scaled_block_scale_fp32 = torch.clamp( scaled_block_scale_fp32, min=FLOAT8_E4M3_EPS, max=FLOAT8_E4M3_MAX - ).to(torch.float8_e4m3fn) - scaled_block_scale_fp8_fp32 = scaled_block_scale_fp8.to(torch.float) - total_scale = scaled_block_scale_fp8_fp32 / a_global_scale + ) + scaled_block_scale_fp8 = scaled_block_scale_fp32.to(torch.float8_e4m3fn) + total_scale = scaled_block_scale_fp32 / a_global_scale a_scaled = a_fp32 / total_scale.unsqueeze(-1) a_scaled = torch.clamp(a_scaled, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX) a_scaled = a_scaled.view(original_shape) From 83c284dd4308d0a0f6ce32f29ed5babdfccee8f0 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 7 Oct 2025 17:57:41 -0700 Subject: [PATCH 02/14] not sure if I needed this one --- csrc/dynamic_transform.cpp | 14 ++++++++++++++ csrc/ir/internal_base_nodes.h | 8 +++++--- csrc/ir/nodes.cpp | 9 ++++++--- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 7aba9fef614..2e4dfa962bb 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -848,6 +849,19 @@ void DynamicTransformConcretizer::concretize() { OptOutMutator::dispatchMutate(stmt); } + // OptOutMutator only updates ID, but not exprs on its extent. Concretization on the allocation domain of layout op output needs to be manually replaced, because they are not connect by ID ops. + auto exprs = info_->fusion()->exprs(); + for (auto* layout_op : ir_utils::filterByType(exprs)) { + auto* out_tv = layout_op->out()->as(); + std::vector logical_dom = TensorDomain::noReductions(out_tv->getLogicalDomain()); + std::vector alloc_dom = + layoutAllocationDomain( + logical_dom, + layout_op->g(), + layout_op->layout()); + out_tv->domain()->setAllocationDomain(alloc_dom, true, true); + } + for (Val* outp : info_->fusion()->outputs()) { Val* new_outp = maybeMutated(outp); if (new_outp != outp) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 83410dccb3e..f0efdad42aa 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -665,17 +665,19 @@ class NVF_API TensorDomain : public Val { // accordingly. NVF_API void setAllocationDomain( std::vector new_allocation_domain, - std::vector> new_contiguity); + std::vector> new_contiguity, + bool skip_validation = false); // Similar to the previous one, but with new contiguity filled with all true // or all false. void setAllocationDomain( std::vector new_allocation_domain, - bool new_contiguity) { + bool new_contiguity, + bool skip_validation = false) { auto contiguity_flags = getContiguityFilledWith(new_allocation_domain, new_contiguity); setAllocationDomain( - std::move(new_allocation_domain), std::move(contiguity_flags)); + std::move(new_allocation_domain), std::move(contiguity_flags), skip_validation); } // i here is int, as we want to accept negative value and ::size_type can be a diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 2971b542486..95358ec29be 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -4083,11 +4083,14 @@ void TensorDomain::setAlternateLoopDomain( void TensorDomain::setAllocationDomain( std::vector new_allocation_domain, - std::vector> new_contiguity) { + std::vector> new_contiguity, + bool skip_validation) { validateContiguity(new_allocation_domain, new_contiguity); - ir_utils::validateDomainEquivalence( - logical_domain_, new_allocation_domain, additional_ids_); + if (!skip_validation) { + ir_utils::validateDomainEquivalence( + logical_domain_, new_allocation_domain, additional_ids_); + } allocation_domain_ = std::move(new_allocation_domain); contiguity_ = std::move(new_contiguity); From 3ef3fc0614968a83a45abdb1d84d47ebb287e215 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 8 Oct 2025 11:24:21 -0700 Subject: [PATCH 03/14] clangformat --- csrc/dynamic_transform.cpp | 21 ++++++++++++--------- csrc/ir/internal_base_nodes.h | 4 +++- python/python_direct/ops.cpp | 8 ++++---- python/python_direct/python_translate.cpp | 4 +++- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 2e4dfa962bb..0d391ec9c03 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -14,8 +14,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -849,16 +849,19 @@ void DynamicTransformConcretizer::concretize() { OptOutMutator::dispatchMutate(stmt); } - // OptOutMutator only updates ID, but not exprs on its extent. Concretization on the allocation domain of layout op output needs to be manually replaced, because they are not connect by ID ops. + // OptOutMutator only updates ID, but not exprs on its extent. Concretization + // on the allocation domain of layout op output needs to be manually replaced, + // because they are not connect by ID ops. auto exprs = info_->fusion()->exprs(); - for (auto* layout_op : ir_utils::filterByType(exprs)) { + for (auto* layout_op : + ir_utils::filterByType(exprs)) { auto* out_tv = layout_op->out()->as(); - std::vector logical_dom = TensorDomain::noReductions(out_tv->getLogicalDomain()); - std::vector alloc_dom = - layoutAllocationDomain( - logical_dom, - layout_op->g(), - layout_op->layout()); + std::vector logical_dom = + TensorDomain::noReductions(out_tv->getLogicalDomain()); + std::vector alloc_dom = layoutAllocationDomain( + logical_dom, layout_op->g(), layout_op->layout()); + // skip validation because allocation domain doesn't converge to logical + // domain. out_tv->domain()->setAllocationDomain(alloc_dom, true, true); } diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index f0efdad42aa..2832fd442e5 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -677,7 +677,9 @@ class NVF_API TensorDomain : public Val { auto contiguity_flags = getContiguityFilledWith(new_allocation_domain, new_contiguity); setAllocationDomain( - std::move(new_allocation_domain), std::move(contiguity_flags), skip_validation); + std::move(new_allocation_domain), + std::move(contiguity_flags), + skip_validation); } // i here is int, as we want to accept negative value and ::size_type can be a diff --git a/python/python_direct/ops.cpp b/python/python_direct/ops.cpp index abf2ac31449..bbe2a5f4599 100644 --- a/python/python_direct/ops.cpp +++ b/python/python_direct/ops.cpp @@ -3212,11 +3212,11 @@ Parameters input : TensorView A 2D tensor containing blockwise scaling factor input_offsets: TensorView - A 1D tensor with length as (1 + number of groups). - Its value notes the offsets of the starting token in each group, where the last entry contains the total number of token + A 1D tensor with length as `number of groups`. + Its value notes the offsets of the starting token in each group for the input tensor view output_offsets: TensorView - A 1D tensor with length as (1 + number of groups). - Its value notes the offsets of the starting token in each group at the output tensor view. + A 1D tensor with length as `number of groups`. + Its value notes the offsets of the starting token in each group for the output tensor view. Returns ------- diff --git a/python/python_direct/python_translate.cpp b/python/python_direct/python_translate.cpp index 77e113ce397..2e903b2eca6 100644 --- a/python/python_direct/python_translate.cpp +++ b/python/python_direct/python_translate.cpp @@ -1620,7 +1620,9 @@ class PythonTranslator : public OptInConstDispatch { visited_vals_.insert(layout_op->output(0)); printer_.generateOperation( "fd.ops.preprocess_grouped_matmul_input_sf", - {layout_op->in()->as(), layout_op->inputOffsets(), layout_op->outputOffsets()}, + {layout_op->in()->as(), + layout_op->inputOffsets(), + layout_op->outputOffsets()}, {layout_op->out()}); } From 1a7b468c5d2440110289dbdd0460d6a65110693e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 8 Oct 2025 11:46:34 -0700 Subject: [PATCH 04/14] cleaning python test --- tests/python/direct/test_narrow_precision.py | 33 +++++++++----------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 2826642e2d9..8047a8cfd20 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -187,9 +187,6 @@ def test_cutlass_nvfp4_grouped_mm( mat1_ref, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE ) - ab_strides = torch.full((g,), k, dtype=torch.int64, device="cuda:0") - c_strides = torch.full((g,), n, dtype=torch.int64, device="cuda:0") - def nvfuser_fusion_id0(fd: FusionDefinition) -> None: mat1 = fd.define_tensor( shape=[-1, -1], @@ -281,10 +278,10 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) +# TODO: update reference implementation to support padding on k @pytest.mark.skipif( is_pre_blackwell(), reason="Only supported on blackwell and newer devices." ) -#@pytest.mark.parametrize("config", [[1024, 128, 16*9]]) @pytest.mark.parametrize("config", [[1024, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16]) @@ -297,8 +294,9 @@ def test_layout_op_and_cutlass_nvfp4_grouped_mm( INPUT_DTYPE = torch.uint8 BLOCK_SIZE = 16 - # k dimension is multiple of 128 to avoid padding + # k dimension is multiple of 4 * 16 to avoid padding on block scaling factor m, n, k = config + assert k % 64 == 0 tokens_per_expert = tokens_per_expert_neg_one tokens_per_expert.append(m - sum(tokens_per_expert)) g = len(tokens_per_expert) @@ -324,7 +322,7 @@ def test_layout_op_and_cutlass_nvfp4_grouped_mm( ) for i in range(g): - mat2_gs[i] = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max() + global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max() offsets[i] = acc_tokens blockscale_offsets[i] = rounded_acc_tokens acc_tokens += tokens_per_expert[i] @@ -335,14 +333,11 @@ def test_layout_op_and_cutlass_nvfp4_grouped_mm( problem_sizes[i][1] = n problem_sizes[i][2] = k - scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], mat2_gs[i]) + scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf) + mat2_gs[i] = 1.0 / global_sf mat2_scaled[i] = scaled_mat2_i scale2[i] = linear_to_swizzled_128_4(bs_mat2_i) - - ab_strides = torch.full((g,), k, dtype=torch.int64, device="cuda:0") - c_strides = torch.full((g,), n, dtype=torch.int64, device="cuda:0") - def nvfuser_fusion_id0(fd: FusionDefinition) -> None: mat1 = fd.define_tensor( shape=[-1, -1], @@ -376,31 +371,33 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False ) # TODO: fix dynamic shape in issue https://github.com/NVIDIA/Fuser/issues/5199 - m_size = m - k_size = k - k_tile_size = k_size // 16; # m_size = fd.ops.size(mat1, 0) # k_size = fd.ops.size(mat1, 1) # k_tile_size = fd.ops.div(k_size, 16) + # use static shape as a temporary WAR. + m_size = m + k_size = k + k_tile_size = k_size // 16; # using primitive operations to handle quantization reshaped_mat1 = fd.ops.reshape(mat1, [m_size, k_tile_size, 16]) + # quantization math to compute block scaling factor scale1 = fd.ops.abs(reshaped_mat1) scale1 = fd.ops.max(scale1, 2) scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX) scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX) - broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True]) reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1) reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX) scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size]) - # should I clamp here before cast?! + + # cast the quantized tv and block sf to proper dtype fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn) fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn) - # NOTE: I need to add an entry for translation rule to print out this + + # swizzle & pad block sf layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, offsets, blockscale_offsets) - # NOTE: it's not working with the grouped_mm. Looks like segmentation is a bit different. But I think it's also exposing some dependency issue above. out = fd.ops.cutlass_nvfp4_grouped_mm( fp4_mat1, mat2, From 2c2bc633904492802419ef153e12eefbb69f5a83 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 8 Oct 2025 12:05:20 -0700 Subject: [PATCH 05/14] clangformat, add war in python tests --- tests/python/direct/test_narrow_precision.py | 17 +++++++++++------ tests/python/utils/utils.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 8047a8cfd20..630c9d6dd62 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -131,7 +131,6 @@ def test_cutlass_nvfp4_grouped_mm( tokens_per_expert_neg_one, out_dtype, ): - INPUT_DTYPE = torch.uint8 BLOCK_SIZE = 16 # k dimension is multiple of 128 to avoid padding @@ -291,7 +290,6 @@ def test_layout_op_and_cutlass_nvfp4_grouped_mm( tokens_per_expert_neg_one, out_dtype, ): - INPUT_DTYPE = torch.uint8 BLOCK_SIZE = 16 # k dimension is multiple of 4 * 16 to avoid padding on block scaling factor @@ -377,7 +375,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: # use static shape as a temporary WAR. m_size = m k_size = k - k_tile_size = k_size // 16; + k_tile_size = k_size // 16 # using primitive operations to handle quantization reshaped_mat1 = fd.ops.reshape(mat1, [m_size, k_tile_size, 16]) @@ -388,7 +386,9 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX) broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True]) reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1) - reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX) + reshaped_scaled_mat1 = fd.ops.clamp( + reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX + ) scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size]) @@ -397,7 +397,9 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn) # swizzle & pad block sf - layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, offsets, blockscale_offsets) + layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf( + fp8_scale1, offsets, blockscale_offsets + ) out = fd.ops.cutlass_nvfp4_grouped_mm( fp4_mat1, mat2, @@ -421,7 +423,10 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: blockscale_offsets, ] - o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) + # FIXME: force indexing to use IdModel indexer to avoid indexing error. + # see issue: https://github.com/NVIDIA/Fuser/issues/5200 + with set_env({"NVFUSER_ENABLE" : "id_model(all)"}): + o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) # quantization for activation is needed for reference. # note: following sglang implementation, not computing global scaling factor for mat1 diff --git a/tests/python/utils/utils.py b/tests/python/utils/utils.py index e2611e501a6..3db0173e791 100644 --- a/tests/python/utils/utils.py +++ b/tests/python/utils/utils.py @@ -343,3 +343,16 @@ def exec_nvfuser( check_cpp_translation(out, fd, inputs_cloned, supports_segmentation) ) return out, fd + +@contextmanager +def set_env(**environ): + """ + Override environment variable + """ + old_environ = dict(os.environ) + os.environ.update(environ) + try: + yield + finally: + os.environ.clear() + os.environ.update(old_environ) From e04076cc1da27acaf999a888c6dee08a4ba22c28 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 8 Oct 2025 13:22:04 -0700 Subject: [PATCH 06/14] black --- tests/python/direct/test_narrow_precision.py | 2 +- tests/python/utils/utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 630c9d6dd62..0ed5c9f0aac 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -425,7 +425,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: # FIXME: force indexing to use IdModel indexer to avoid indexing error. # see issue: https://github.com/NVIDIA/Fuser/issues/5200 - with set_env({"NVFUSER_ENABLE" : "id_model(all)"}): + with set_env({"NVFUSER_ENABLE": "id_model(all)"}): o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) # quantization for activation is needed for reference. diff --git a/tests/python/utils/utils.py b/tests/python/utils/utils.py index 3db0173e791..f6d0ddadc3e 100644 --- a/tests/python/utils/utils.py +++ b/tests/python/utils/utils.py @@ -344,6 +344,7 @@ def exec_nvfuser( ) return out, fd + @contextmanager def set_env(**environ): """ From 2088329910df8623e206ab902327e38c91c16732 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 8 Oct 2025 13:53:20 -0700 Subject: [PATCH 07/14] fixing python tests --- tests/python/direct/test_narrow_precision.py | 3 ++- tests/python/utils/utils.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 0ed5c9f0aac..a245bdbe16e 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -10,6 +10,7 @@ DataType, ) from nvfuser_direct.pytorch_utils import torch_dtype_to_nvfuser_dtype +from python.utils import set_env from python.direct_utils import ( FLOAT4_E2M1_MAX, FLOAT8_E4M3_EPS, @@ -425,7 +426,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: # FIXME: force indexing to use IdModel indexer to avoid indexing error. # see issue: https://github.com/NVIDIA/Fuser/issues/5200 - with set_env({"NVFUSER_ENABLE": "id_model(all)"}): + with set_env(NVFUSER_ENABLE="id_model(all)"): o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) # quantization for activation is needed for reference. diff --git a/tests/python/utils/utils.py b/tests/python/utils/utils.py index f6d0ddadc3e..34cd3e5e12e 100644 --- a/tests/python/utils/utils.py +++ b/tests/python/utils/utils.py @@ -9,6 +9,7 @@ import tempfile import torch import pytest +from contextlib import contextmanager from torch.testing import make_tensor from torch.testing._internal.common_utils import TestCase from looseversion import LooseVersion From 9e74121cdb51ed8329e0faecb04e9954d8f8d505 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 17:36:58 -0700 Subject: [PATCH 08/14] revert setAllocationDomain API change --- csrc/dynamic_transform.cpp | 2 +- csrc/ir/internal_base_nodes.h | 10 +++------- csrc/ir/nodes.cpp | 9 +++------ 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 0d391ec9c03..36287bfe9c6 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -862,7 +862,7 @@ void DynamicTransformConcretizer::concretize() { logical_dom, layout_op->g(), layout_op->layout()); // skip validation because allocation domain doesn't converge to logical // domain. - out_tv->domain()->setAllocationDomain(alloc_dom, true, true); + out_tv->domain()->setAllocationDomain(alloc_dom, /*new_contiguity=*/true); } for (Val* outp : info_->fusion()->outputs()) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 2832fd442e5..83410dccb3e 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -665,21 +665,17 @@ class NVF_API TensorDomain : public Val { // accordingly. NVF_API void setAllocationDomain( std::vector new_allocation_domain, - std::vector> new_contiguity, - bool skip_validation = false); + std::vector> new_contiguity); // Similar to the previous one, but with new contiguity filled with all true // or all false. void setAllocationDomain( std::vector new_allocation_domain, - bool new_contiguity, - bool skip_validation = false) { + bool new_contiguity) { auto contiguity_flags = getContiguityFilledWith(new_allocation_domain, new_contiguity); setAllocationDomain( - std::move(new_allocation_domain), - std::move(contiguity_flags), - skip_validation); + std::move(new_allocation_domain), std::move(contiguity_flags)); } // i here is int, as we want to accept negative value and ::size_type can be a diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 95358ec29be..2971b542486 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -4083,14 +4083,11 @@ void TensorDomain::setAlternateLoopDomain( void TensorDomain::setAllocationDomain( std::vector new_allocation_domain, - std::vector> new_contiguity, - bool skip_validation) { + std::vector> new_contiguity) { validateContiguity(new_allocation_domain, new_contiguity); - if (!skip_validation) { - ir_utils::validateDomainEquivalence( - logical_domain_, new_allocation_domain, additional_ids_); - } + ir_utils::validateDomainEquivalence( + logical_domain_, new_allocation_domain, additional_ids_); allocation_domain_ = std::move(new_allocation_domain); contiguity_ = std::move(new_contiguity); From bddd525922384aef1b567d1c83658bd7d10d4c8f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 17:41:15 -0700 Subject: [PATCH 09/14] addressing review comments --- python/python_direct/ops.cpp | 4 ++-- tests/python/direct/test_narrow_precision.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/python_direct/ops.cpp b/python/python_direct/ops.cpp index bbe2a5f4599..a89c48aa082 100644 --- a/python/python_direct/ops.cpp +++ b/python/python_direct/ops.cpp @@ -3205,7 +3205,7 @@ Notes py::arg("input_offsets"), py::arg("output_offsets"), R"( -Layout operation to apply per group swizzle & padding for grouped matmul block scaling factor for activation. +Layout operation to apply per group swizzle & padding to the block scaling factor of the input activations to grouped matmul. Parameters ---------- @@ -3221,7 +3221,7 @@ output_offsets: TensorView Returns ------- TensorView - A tensor with proper swizzle & padding in memory. Note that the actual padding in buffer is not represented by the size/stride of the output tensor. + A tensor with proper swizzle & padding in memory. Note that the actual padding in buffer is not represented by the size/stride of the output tensor. )", py::return_value_policy::reference); } diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index a245bdbe16e..71d1f60702c 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -280,7 +280,7 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: # TODO: update reference implementation to support padding on k @pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." + not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" ) @pytest.mark.parametrize("config", [[1024, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) From 2598dede7db4a85f7e91c5d887108b562a8b9c42 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 14 Oct 2025 01:56:49 -0700 Subject: [PATCH 10/14] Revert "revert setAllocationDomain API change" This reverts commit eba5fe1bdbc9bf579f5487c122b6dd68d704342a. --- csrc/dynamic_transform.cpp | 2 +- csrc/ir/internal_base_nodes.h | 10 +++++++--- csrc/ir/nodes.cpp | 9 ++++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 36287bfe9c6..0d391ec9c03 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -862,7 +862,7 @@ void DynamicTransformConcretizer::concretize() { logical_dom, layout_op->g(), layout_op->layout()); // skip validation because allocation domain doesn't converge to logical // domain. - out_tv->domain()->setAllocationDomain(alloc_dom, /*new_contiguity=*/true); + out_tv->domain()->setAllocationDomain(alloc_dom, true, true); } for (Val* outp : info_->fusion()->outputs()) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 83410dccb3e..2832fd442e5 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -665,17 +665,21 @@ class NVF_API TensorDomain : public Val { // accordingly. NVF_API void setAllocationDomain( std::vector new_allocation_domain, - std::vector> new_contiguity); + std::vector> new_contiguity, + bool skip_validation = false); // Similar to the previous one, but with new contiguity filled with all true // or all false. void setAllocationDomain( std::vector new_allocation_domain, - bool new_contiguity) { + bool new_contiguity, + bool skip_validation = false) { auto contiguity_flags = getContiguityFilledWith(new_allocation_domain, new_contiguity); setAllocationDomain( - std::move(new_allocation_domain), std::move(contiguity_flags)); + std::move(new_allocation_domain), + std::move(contiguity_flags), + skip_validation); } // i here is int, as we want to accept negative value and ::size_type can be a diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 2971b542486..95358ec29be 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -4083,11 +4083,14 @@ void TensorDomain::setAlternateLoopDomain( void TensorDomain::setAllocationDomain( std::vector new_allocation_domain, - std::vector> new_contiguity) { + std::vector> new_contiguity, + bool skip_validation) { validateContiguity(new_allocation_domain, new_contiguity); - ir_utils::validateDomainEquivalence( - logical_domain_, new_allocation_domain, additional_ids_); + if (!skip_validation) { + ir_utils::validateDomainEquivalence( + logical_domain_, new_allocation_domain, additional_ids_); + } allocation_domain_ = std::move(new_allocation_domain); contiguity_ = std::move(new_contiguity); From b1da34be9cd37c107f567712496d5defc907614f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 14 Oct 2025 02:01:41 -0700 Subject: [PATCH 11/14] wip --- tests/python/direct/test_narrow_precision.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 71d1f60702c..5f25b67140e 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -279,6 +279,9 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: # TODO: update reference implementation to support padding on k +@pytest.mark.skipif( + is_pre_blackwell(), reason="Only supported on blackwell and newer devices." +) @pytest.mark.skipif( not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" ) From fe8da6207c6705be07bb5830920c5ecffa51f74b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 16 Oct 2025 13:22:01 -0700 Subject: [PATCH 12/14] revert unwanted changes --- csrc/dynamic_transform.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 0d391ec9c03..7aba9fef614 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -849,22 +848,6 @@ void DynamicTransformConcretizer::concretize() { OptOutMutator::dispatchMutate(stmt); } - // OptOutMutator only updates ID, but not exprs on its extent. Concretization - // on the allocation domain of layout op output needs to be manually replaced, - // because they are not connect by ID ops. - auto exprs = info_->fusion()->exprs(); - for (auto* layout_op : - ir_utils::filterByType(exprs)) { - auto* out_tv = layout_op->out()->as(); - std::vector logical_dom = - TensorDomain::noReductions(out_tv->getLogicalDomain()); - std::vector alloc_dom = layoutAllocationDomain( - logical_dom, layout_op->g(), layout_op->layout()); - // skip validation because allocation domain doesn't converge to logical - // domain. - out_tv->domain()->setAllocationDomain(alloc_dom, true, true); - } - for (Val* outp : info_->fusion()->outputs()) { Val* new_outp = maybeMutated(outp); if (new_outp != outp) { From ac54b62a4ab3c3d8fb83f94133c58040784b0377 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 20 Oct 2025 09:30:20 -0700 Subject: [PATCH 13/14] stripping test into separate file --- tests/python/direct/test_narrow_precision.py | 190 --------------- .../direct/test_with_id_model_indexer.py | 217 ++++++++++++++++++ 2 files changed, 217 insertions(+), 190 deletions(-) create mode 100644 tests/python/direct/test_with_id_model_indexer.py diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py index 5f25b67140e..0285258d6f6 100644 --- a/tests/python/direct/test_narrow_precision.py +++ b/tests/python/direct/test_narrow_precision.py @@ -10,10 +10,8 @@ DataType, ) from nvfuser_direct.pytorch_utils import torch_dtype_to_nvfuser_dtype -from python.utils import set_env from python.direct_utils import ( FLOAT4_E2M1_MAX, - FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX, pytorch_nvfp4_quantize, is_pre_blackwell, @@ -276,191 +274,3 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None: ) assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) - - -# TODO: update reference implementation to support padding on k -@pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." -) -@pytest.mark.skipif( - not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" -) -@pytest.mark.parametrize("config", [[1024, 128, 256]]) -@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) -def test_layout_op_and_cutlass_nvfp4_grouped_mm( - nvfuser_direct_test, - config, - tokens_per_expert_neg_one, - out_dtype, -): - BLOCK_SIZE = 16 - - # k dimension is multiple of 4 * 16 to avoid padding on block scaling factor - m, n, k = config - assert k % 64 == 0 - tokens_per_expert = tokens_per_expert_neg_one - tokens_per_expert.append(m - sum(tokens_per_expert)) - g = len(tokens_per_expert) - - mat1 = torch.testing.make_tensor((m, k), dtype=torch.float32, device="cuda:0") - # format is g, n, k instead of g, k, n - mat2 = torch.testing.make_tensor((g, n, k), dtype=torch.float32, device="cuda:0") - - offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0") - blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0") - problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0") - - # prepare quantization for mat2 - mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0") - scale2 = torch.empty( - (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0" - ) - - acc_tokens = 0 - rounded_acc_tokens = 0 - mat2_scaled = torch.empty( - (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0" - ) - - for i in range(g): - global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max() - offsets[i] = acc_tokens - blockscale_offsets[i] = rounded_acc_tokens - acc_tokens += tokens_per_expert[i] - # Note: we technically don't need to round up, since k is perfectly sized. - rounded_acc_tokens += round_up(tokens_per_expert[i], 128) - - problem_sizes[i][0] = tokens_per_expert[i] - problem_sizes[i][1] = n - problem_sizes[i][2] = k - - scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf) - mat2_gs[i] = 1.0 / global_sf - mat2_scaled[i] = scaled_mat2_i - scale2[i] = linear_to_swizzled_128_4(bs_mat2_i) - - def nvfuser_fusion_id0(fd: FusionDefinition) -> None: - mat1 = fd.define_tensor( - shape=[-1, -1], - contiguity=True, - dtype=DataType.Float, - is_cpu=False, - ) - mat2 = fd.define_tensor( - shape=[-1, -1, -1], - contiguity=True, - dtype=DataType.Float4_e2m1fn, - is_cpu=False, - stride_order=[2, 0, 1], - ) - scale2 = fd.define_tensor( - shape=[-1, -1, -1], - contiguity=True, - dtype=DataType.Float8_e4m3fn, - is_cpu=False, - ) - alpha = fd.define_tensor( - shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False - ) - problem_sizes = fd.define_tensor( - shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False - ) - offsets = fd.define_tensor( - shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False - ) - blockscale_offsets = fd.define_tensor( - shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False - ) - # TODO: fix dynamic shape in issue https://github.com/NVIDIA/Fuser/issues/5199 - # m_size = fd.ops.size(mat1, 0) - # k_size = fd.ops.size(mat1, 1) - # k_tile_size = fd.ops.div(k_size, 16) - # use static shape as a temporary WAR. - m_size = m - k_size = k - k_tile_size = k_size // 16 - # using primitive operations to handle quantization - reshaped_mat1 = fd.ops.reshape(mat1, [m_size, k_tile_size, 16]) - - # quantization math to compute block scaling factor - scale1 = fd.ops.abs(reshaped_mat1) - scale1 = fd.ops.max(scale1, 2) - scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX) - scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX) - broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True]) - reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1) - reshaped_scaled_mat1 = fd.ops.clamp( - reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX - ) - - scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size]) - - # cast the quantized tv and block sf to proper dtype - fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn) - fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn) - - # swizzle & pad block sf - layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf( - fp8_scale1, offsets, blockscale_offsets - ) - out = fd.ops.cutlass_nvfp4_grouped_mm( - fp4_mat1, - mat2, - layout_fp8_scale1, - scale2, - alpha, - problem_sizes, - offsets, - blockscale_offsets, - DataType.BFloat16, - ) - fd.add_output(out) - - inputs = [ - mat1, - mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2), - scale2, - mat2_gs, - problem_sizes, - offsets, - blockscale_offsets, - ] - - # FIXME: force indexing to use IdModel indexer to avoid indexing error. - # see issue: https://github.com/NVIDIA/Fuser/issues/5200 - with set_env(NVFUSER_ENABLE="id_model(all)"): - o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) - - # quantization for activation is needed for reference. - # note: following sglang implementation, not computing global scaling factor for mat1 - # similarly, we don't need to apply mat1_gs to alpha - mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0") - mat1_fp4, scale1 = activation_scale_to_nvfp4( - mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE - ) - o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0") - for i in range(g): - l = offsets[i] - l_sf = blockscale_offsets[i] - if i == g - 1: - r = m - else: - r = offsets[i + 1] - r_sf = round_up(tokens_per_expert[i], 128) + l_sf - # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel. - # This triggers a cublas invalid value error. - o_decomposed_ref[l:r] = ( - torch._scaled_mm( - mat1_fp4[l:r], - mat2_scaled[i].transpose(-1, -2), - scale1[l_sf:r_sf], - scale2[i], - None, - None, - torch.bfloat16, - ) - * mat2_gs[i] - ) - - assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) diff --git a/tests/python/direct/test_with_id_model_indexer.py b/tests/python/direct/test_with_id_model_indexer.py new file mode 100644 index 00000000000..b413ce750ed --- /dev/null +++ b/tests/python/direct/test_with_id_model_indexer.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# Owner(s): ["module: nvfuser"] + +import torch + +from nvfuser_direct import ( + FusionDefinition, + DataType, +) +from python.utils import set_env +from python.direct_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_EPS, + FLOAT8_E4M3_MAX, + pytorch_nvfp4_quantize, + is_pre_blackwell, + microarchitecture_is_pre, + linear_to_swizzled_128_4, + round_up, + activation_scale_to_nvfp4, +) + +import pytest + + +# FIXME: this test needs to be merged back into test_narrow_precision.py. +# We have indexer issue: https://github.com/NVIDIA/Fuser/issues/5200, which +# forces the adoption of environment variable in order to avoid codegen +# assertion. Having this as a separate test file would avoid environment +# variable contamination from others. +@pytest.mark.skipif( + is_pre_blackwell(), reason="Only supported on blackwell and newer devices." +) +@pytest.mark.skipif( + not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0" +) +@pytest.mark.parametrize("config", [[1024, 128, 256]]) +@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) +def test_layout_op_and_cutlass_nvfp4_grouped_mm( + nvfuser_direct_test, + config, + tokens_per_expert_neg_one, + out_dtype, +): + BLOCK_SIZE = 16 + + # k dimension is multiple of 4 * 16 to avoid padding on block scaling factor + m, n, k = config + assert k % 64 == 0 + tokens_per_expert = tokens_per_expert_neg_one + tokens_per_expert.append(m - sum(tokens_per_expert)) + g = len(tokens_per_expert) + + mat1 = torch.testing.make_tensor((m, k), dtype=torch.float32, device="cuda:0") + # format is g, n, k instead of g, k, n + mat2 = torch.testing.make_tensor((g, n, k), dtype=torch.float32, device="cuda:0") + + offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0") + blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0") + problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0") + + # prepare quantization for mat2 + mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0") + scale2 = torch.empty( + (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0" + ) + + acc_tokens = 0 + rounded_acc_tokens = 0 + mat2_scaled = torch.empty( + (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0" + ) + + for i in range(g): + global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max() + offsets[i] = acc_tokens + blockscale_offsets[i] = rounded_acc_tokens + acc_tokens += tokens_per_expert[i] + # Note: we technically don't need to round up, since k is perfectly sized. + rounded_acc_tokens += round_up(tokens_per_expert[i], 128) + + problem_sizes[i][0] = tokens_per_expert[i] + problem_sizes[i][1] = n + problem_sizes[i][2] = k + + scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf) + mat2_gs[i] = 1.0 / global_sf + mat2_scaled[i] = scaled_mat2_i + scale2[i] = linear_to_swizzled_128_4(bs_mat2_i) + + def nvfuser_fusion_id0(fd: FusionDefinition) -> None: + mat1 = fd.define_tensor( + shape=[-1, -1], + contiguity=True, + dtype=DataType.Float, + is_cpu=False, + ) + mat2 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=True, + dtype=DataType.Float4_e2m1fn, + is_cpu=False, + stride_order=[2, 0, 1], + ) + scale2 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=True, + dtype=DataType.Float8_e4m3fn, + is_cpu=False, + ) + alpha = fd.define_tensor( + shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False + ) + problem_sizes = fd.define_tensor( + shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False + ) + offsets = fd.define_tensor( + shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False + ) + blockscale_offsets = fd.define_tensor( + shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False + ) + # TODO: fix dynamic shape in issue https://github.com/NVIDIA/Fuser/issues/5199 + # m_size = fd.ops.size(mat1, 0) + # k_size = fd.ops.size(mat1, 1) + # k_tile_size = fd.ops.div(k_size, 16) + # use static shape as a temporary WAR. + m_size = m + k_size = k + k_tile_size = k_size // 16 + # using primitive operations to handle quantization + reshaped_mat1 = fd.ops.reshape(mat1, [m_size, k_tile_size, 16]) + + # quantization math to compute block scaling factor + scale1 = fd.ops.abs(reshaped_mat1) + scale1 = fd.ops.max(scale1, 2) + scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX) + scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX) + broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True]) + reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1) + reshaped_scaled_mat1 = fd.ops.clamp( + reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX + ) + + scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size]) + + # cast the quantized tv and block sf to proper dtype + fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn) + fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn) + + # swizzle & pad block sf + layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf( + fp8_scale1, offsets, blockscale_offsets + ) + out = fd.ops.cutlass_nvfp4_grouped_mm( + fp4_mat1, + mat2, + layout_fp8_scale1, + scale2, + alpha, + problem_sizes, + offsets, + blockscale_offsets, + DataType.BFloat16, + ) + fd.add_output(out) + + inputs = [ + mat1, + mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2), + scale2, + mat2_gs, + problem_sizes, + offsets, + blockscale_offsets, + ] + + # FIXME: force indexing to use IdModel indexer to avoid indexing error. + # see issue: https://github.com/NVIDIA/Fuser/issues/5200 + with set_env(NVFUSER_ENABLE="id_model(all)"): + o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs) + + # quantization for activation is needed for reference. + # note: following sglang implementation, not computing global scaling factor for mat1 + # similarly, we don't need to apply mat1_gs to alpha + mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0") + mat1_fp4, scale1 = activation_scale_to_nvfp4( + mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE + ) + o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0") + for i in range(g): + l = offsets[i] + l_sf = blockscale_offsets[i] + if i == g - 1: + r = m + else: + r = offsets[i + 1] + r_sf = round_up(tokens_per_expert[i], 128) + l_sf + # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel. + # This triggers a cublas invalid value error. + o_decomposed_ref[l:r] = ( + torch._scaled_mm( + mat1_fp4[l:r], + mat2_scaled[i].transpose(-1, -2), + scale1[l_sf:r_sf], + scale2[i], + None, + None, + torch.bfloat16, + ) + * mat2_gs[i] + ) + + assert torch.allclose(o_decomposed_ref, o[0], atol=1e-2, rtol=1e-2) From ce77b621560b65252d8109596d9e0ee41c16cde9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 20 Oct 2025 15:29:50 -0700 Subject: [PATCH 14/14] relax checks --- tests/python/direct/test_cutlass_nvfp4_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/direct/test_cutlass_nvfp4_gemm.py b/tests/python/direct/test_cutlass_nvfp4_gemm.py index d5f63d3cf07..d1bdb32865a 100644 --- a/tests/python/direct/test_cutlass_nvfp4_gemm.py +++ b/tests/python/direct/test_cutlass_nvfp4_gemm.py @@ -161,7 +161,7 @@ def test_nvfp4_gemm_epilogue( # The percentage of mismatched values is 1%. nonzero = torch.count_nonzero(torch.ne(abs_diff, 0.0)) - assert (nonzero / abs_diff.numel()) < 0.01 + assert (nonzero / abs_diff.numel()) < 0.1 # Compare scale factors # rtol = epsilon = 2**(-3) for fp8_m4e3