From e52178df8e3b3a2729ee92a52e83394e459e3bc3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 13:35:22 -0700 Subject: [PATCH 1/8] adding tests --- tests/python/direct/test_python_frontend.py | 29 +++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index c6ddafce8bc..f89788f5dd3 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -23,6 +23,11 @@ verify_stride_order, ) +from python.direct_utils.narrow_precision import ( + pytorch_nvfp4_quantize, + unpack_fp4_bytes, +) + def test_basic(nvfuser_direct_test): inputs = [ @@ -2615,3 +2620,27 @@ def fusion_func(fd: FusionDefinition): RuntimeError, match="KernelExecutor does not support the Fusion provided." ): _ = fd.execute(inputs) + + +# Test that we properly handle packed type +def test_packed_fp4(nvfuser_direct_test): + t0 = torch.rand((32,), dtype=torch.float32, device="cuda:0") + # we'll just ignore the scaling factor, since we only want to test basic fp4 support + t0_fp4, _ = pytorch_nvfp4_quantize(t0, 1.0) + inputs = [t0_fp4] + + def fusion_func(fd: FusionDefinition): + T0 = fd.define_tensor( + shape=[-1], + contiguity=[True], + dtype=DataType.Float4_e2m1fn, + is_cpu=False, + stride_order=[0], + ) + T1 = fd.ops.relu(T0) + fd.add_output(T1) + + out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs) + nvfuser_direct_test.assertEqual( + unpack_fp4_bytes(out), unpack_fp4_bytes(t0_fp4).relu() + ) From c031cd7209ab7df84e98ebed8896912307bd874d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 13:43:48 -0700 Subject: [PATCH 2/8] fixing tests and adding API support on fp4 packed types --- python/python_direct/ir.cpp | 80 +++++++++++++++++---- tests/python/direct/test_python_frontend.py | 2 +- 2 files changed, 68 insertions(+), 14 deletions(-) diff --git a/python/python_direct/ir.cpp b/python/python_direct/ir.cpp index efc8c56ced9..a857060ac4f 100644 --- a/python/python_direct/ir.cpp +++ b/python/python_direct/ir.cpp @@ -550,6 +550,30 @@ TensorView* defineTensor( return tv; } +// return the unpacked shape and dtype for a given packed dtype, where we need +// to double the size of the inner most dimension. +std::tuple, PrimDataType> translatePackedDtype( + const std::vector& shape, + const PrimDataType dtype, + const std::vector& stride_order) { + // TODO: switch to isPackedType when the pack width is retrieved through + // utility functions as well. + NVF_CHECK(dtype == DataType::Float4_e2m1fn_x2); + + int fastest_dim = -1; + for (const auto& [i, val] : enumerate(stride_order)) { + if (val == 0) { + fastest_dim = i; + break; + } + } + NVF_CHECK(fastest_dim >= 0, "illegal stride_order: ", stride_order); + + std::vector un_packed_shape = shape; + un_packed_shape[fastest_dim] *= 2; + return {un_packed_shape, DataType::Float4_e2m1fn}; +} + void bindDefineTensor(py::module& nvfuser) { nvfuser .def( @@ -560,7 +584,15 @@ void bindDefineTensor(py::module& nvfuser) { const bool is_cpu = false, const std::vector& stride_order = {}) -> TensorView* { verifyShape(shape); - return defineTensor(shape, contiguity, dtype, is_cpu, stride_order); + if (!isPackedType(dtype)) { + return defineTensor( + shape, contiguity, dtype, is_cpu, stride_order); + } else { + auto&& [new_shape, new_dtype] = + translatePackedDtype(shape, dtype, stride_order); + return defineTensor( + new_shape, contiguity, new_dtype, is_cpu, stride_order); + } }, py::arg("shape"), py::arg("contiguity"), @@ -577,12 +609,23 @@ void bindDefineTensor(py::module& nvfuser) { const bool is_cpu = false, const std::vector& stride_order = {}) -> TensorView* { verifyShape(shape); - return defineTensor( - shape, - getContiguityVec(shape, stride_order, contiguity), - dtype, - is_cpu, - stride_order); + if (!isPackedType(dtype)) { + return defineTensor( + shape, + getContiguityVec(shape, stride_order, contiguity), + dtype, + is_cpu, + stride_order); + } else { + auto&& [new_shape, new_dtype] = + translatePackedDtype(shape, dtype, stride_order); + return defineTensor( + new_shape, + getContiguityVec(new_shape, stride_order, contiguity), + new_dtype, + is_cpu, + stride_order); + } }, py::arg("shape"), py::arg("contiguity") = false, @@ -606,12 +649,23 @@ void bindDefineTensor(py::module& nvfuser) { std::vector stride_order; std::tie(contiguity, stride_order) = computeTensorDescriptor(sizes, strides); - return defineTensor( - getTensorViewBuilderSizes(sizes, static_sizes), - contiguity, - dtype, - is_cpu, - stride_order); + if (!isPackedType(dtype)) { + return defineTensor( + getTensorViewBuilderSizes(sizes, static_sizes), + contiguity, + dtype, + is_cpu, + stride_order); + } else { + auto&& [new_sizes, new_dtype] = + translatePackedDtype(sizes, dtype, stride_order); + return defineTensor( + getTensorViewBuilderSizes(new_sizes, static_sizes), + contiguity, + new_dtype, + is_cpu, + stride_order); + } }, py::arg("sizes"), py::arg("strides"), diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index f89788f5dd3..1e62bb43ca7 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -2624,7 +2624,7 @@ def fusion_func(fd: FusionDefinition): # Test that we properly handle packed type def test_packed_fp4(nvfuser_direct_test): - t0 = torch.rand((32,), dtype=torch.float32, device="cuda:0") + t0 = torch.rand((8, 32,), dtype=torch.float32, device="cuda:0") # we'll just ignore the scaling factor, since we only want to test basic fp4 support t0_fp4, _ = pytorch_nvfp4_quantize(t0, 1.0) inputs = [t0_fp4] From 7eb6641ef625842d658ded36db05175afe60beab Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 13:47:07 -0700 Subject: [PATCH 3/8] adding tests --- tests/python/direct/test_python_frontend.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index 1e62bb43ca7..45a68a4235d 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -2631,14 +2631,15 @@ def test_packed_fp4(nvfuser_direct_test): def fusion_func(fd: FusionDefinition): T0 = fd.define_tensor( - shape=[-1], - contiguity=[True], + shape=[-1, -1], + contiguity=[True, True], dtype=DataType.Float4_e2m1fn, is_cpu=False, - stride_order=[0], ) - T1 = fd.ops.relu(T0) - fd.add_output(T1) + T1 = fd.ops.cast(T0, DataType.Float) + T2 = fd.ops.relu(T1) + T3 = fd.ops.cast(T2, DataType.Float4_e2m1fn) + fd.add_output(T3) out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs) nvfuser_direct_test.assertEqual( From 4f4f9bde5fa4487c32e0ff6cb1d7dd957e8edaa7 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 14:27:15 -0700 Subject: [PATCH 4/8] fixing tests --- tests/python/direct/test_python_frontend.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index 45a68a4235d..b9f50054450 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -2624,7 +2624,7 @@ def fusion_func(fd: FusionDefinition): # Test that we properly handle packed type def test_packed_fp4(nvfuser_direct_test): - t0 = torch.rand((8, 32,), dtype=torch.float32, device="cuda:0") + t0 = torch.rand((1024, 32,), dtype=torch.float32, device="cuda:0") # we'll just ignore the scaling factor, since we only want to test basic fp4 support t0_fp4, _ = pytorch_nvfp4_quantize(t0, 1.0) inputs = [t0_fp4] @@ -2638,10 +2638,8 @@ def fusion_func(fd: FusionDefinition): ) T1 = fd.ops.cast(T0, DataType.Float) T2 = fd.ops.relu(T1) - T3 = fd.ops.cast(T2, DataType.Float4_e2m1fn) - fd.add_output(T3) + fd.add_output(T2) out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs) - nvfuser_direct_test.assertEqual( - unpack_fp4_bytes(out), unpack_fp4_bytes(t0_fp4).relu() - ) + ref = unpack_fp4_bytes(t0_fp4, torch.float32).relu() + nvfuser_direct_test.assertEqual(out[0], ref) From d26ccde82e063d1fff98ce405aa4cf6c4dbeb21e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 14:47:33 -0700 Subject: [PATCH 5/8] fixing example --- tests/python/direct/test_python_frontend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index b9f50054450..9aa56091707 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -2631,10 +2631,11 @@ def test_packed_fp4(nvfuser_direct_test): def fusion_func(fd: FusionDefinition): T0 = fd.define_tensor( - shape=[-1, -1], + shape=[1024, 16], contiguity=[True, True], - dtype=DataType.Float4_e2m1fn, + dtype=DataType.Float4_e2m1fn_x2, is_cpu=False, + stride_order=[1, 0], ) T1 = fd.ops.cast(T0, DataType.Float) T2 = fd.ops.relu(T1) From 57182a7df846e136aca22f3849dda7ed807e764c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 14:54:44 -0700 Subject: [PATCH 6/8] cleanup initialization on fastest dim --- python/python_direct/ir.cpp | 4 +--- tests/python/direct/test_python_frontend.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/python_direct/ir.cpp b/python/python_direct/ir.cpp index a857060ac4f..671806ad391 100644 --- a/python/python_direct/ir.cpp +++ b/python/python_direct/ir.cpp @@ -560,15 +560,13 @@ std::tuple, PrimDataType> translatePackedDtype( // utility functions as well. NVF_CHECK(dtype == DataType::Float4_e2m1fn_x2); - int fastest_dim = -1; + int fastest_dim = shape.size() - 1; for (const auto& [i, val] : enumerate(stride_order)) { if (val == 0) { fastest_dim = i; break; } } - NVF_CHECK(fastest_dim >= 0, "illegal stride_order: ", stride_order); - std::vector un_packed_shape = shape; un_packed_shape[fastest_dim] *= 2; return {un_packed_shape, DataType::Float4_e2m1fn}; diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index 9aa56091707..4c276bbd328 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -2635,7 +2635,6 @@ def fusion_func(fd: FusionDefinition): contiguity=[True, True], dtype=DataType.Float4_e2m1fn_x2, is_cpu=False, - stride_order=[1, 0], ) T1 = fd.ops.cast(T0, DataType.Float) T2 = fd.ops.relu(T1) From 2bf1edc85e45d8560f43cc2820c0c6f74c26b5e8 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 13 Oct 2025 14:56:29 -0700 Subject: [PATCH 7/8] Black --- tests/python/direct/test_python_frontend.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index 4c276bbd328..abfbab4ba6d 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -2624,7 +2624,14 @@ def fusion_func(fd: FusionDefinition): # Test that we properly handle packed type def test_packed_fp4(nvfuser_direct_test): - t0 = torch.rand((1024, 32,), dtype=torch.float32, device="cuda:0") + t0 = torch.rand( + ( + 1024, + 32, + ), + dtype=torch.float32, + device="cuda:0", + ) # we'll just ignore the scaling factor, since we only want to test basic fp4 support t0_fp4, _ = pytorch_nvfp4_quantize(t0, 1.0) inputs = [t0_fp4] From af104af0db1a5d83a7173e2342f9233eec3c8b1b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 14 Oct 2025 00:59:01 -0700 Subject: [PATCH 8/8] filter test on pre_blackwell card --- tests/python/direct/test_python_frontend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py index abfbab4ba6d..ef0400469e9 100644 --- a/tests/python/direct/test_python_frontend.py +++ b/tests/python/direct/test_python_frontend.py @@ -2623,6 +2623,9 @@ def fusion_func(fd: FusionDefinition): # Test that we properly handle packed type +@pytest.mark.skipif( + is_pre_blackwell(), reason="Only supported on blackwell and newer devices." +) def test_packed_fp4(nvfuser_direct_test): t0 = torch.rand( (