diff --git a/python/python_direct/ir.cpp b/python/python_direct/ir.cpp
index efc8c56ced9..671806ad391 100644
--- a/python/python_direct/ir.cpp
+++ b/python/python_direct/ir.cpp
@@ -550,6 +550,28 @@ TensorView* defineTensor(
   return tv;
 }
 
+// return the unpacked shape and dtype for a given packed dtype, where we need
+// to double the size of the inner most dimension.
+std::tuple<std::vector<int64_t>, PrimDataType> translatePackedDtype(
+    const std::vector<int64_t>& shape,
+    const PrimDataType dtype,
+    const std::vector<int64_t>& stride_order) {
+  // TODO: switch to isPackedType when the pack width is retrieved through
+  // utility functions as well.
+  NVF_CHECK(dtype == DataType::Float4_e2m1fn_x2);
+
+  int fastest_dim = shape.size() - 1;
+  for (const auto& [i, val] : enumerate(stride_order)) {
+    if (val == 0) {
+      fastest_dim = i;
+      break;
+    }
+  }
+  std::vector<int64_t> un_packed_shape = shape;
+  un_packed_shape[fastest_dim] *= 2;
+  return {un_packed_shape, DataType::Float4_e2m1fn};
+}
+
 void bindDefineTensor(py::module& nvfuser) {
   nvfuser
       .def(
@@ -560,7 +582,15 @@ void bindDefineTensor(py::module& nvfuser) {
              const bool is_cpu = false,
              const std::vector<int64_t>& stride_order = {}) -> TensorView* {
             verifyShape(shape);
-            return defineTensor(shape, contiguity, dtype, is_cpu, stride_order);
+            if (!isPackedType(dtype)) {
+              return defineTensor(
+                  shape, contiguity, dtype, is_cpu, stride_order);
+            } else {
+              auto&& [new_shape, new_dtype] =
+                  translatePackedDtype(shape, dtype, stride_order);
+              return defineTensor(
+                  new_shape, contiguity, new_dtype, is_cpu, stride_order);
+            }
           },
           py::arg("shape"),
           py::arg("contiguity"),
@@ -577,12 +607,23 @@ void bindDefineTensor(py::module& nvfuser) {
              const bool is_cpu = false,
              const std::vector<int64_t>& stride_order = {}) -> TensorView* {
             verifyShape(shape);
-            return defineTensor(
-                shape,
-                getContiguityVec(shape, stride_order, contiguity),
-                dtype,
-                is_cpu,
-                stride_order);
+            if (!isPackedType(dtype)) {
+              return defineTensor(
+                  shape,
+                  getContiguityVec(shape, stride_order, contiguity),
+                  dtype,
+                  is_cpu,
+                  stride_order);
+            } else {
+              auto&& [new_shape, new_dtype] =
+                  translatePackedDtype(shape, dtype, stride_order);
+              return defineTensor(
+                  new_shape,
+                  getContiguityVec(new_shape, stride_order, contiguity),
+                  new_dtype,
+                  is_cpu,
+                  stride_order);
+            }
           },
           py::arg("shape"),
           py::arg("contiguity") = false,
@@ -606,12 +647,23 @@ void bindDefineTensor(py::module& nvfuser) {
             std::vector<int64_t> stride_order;
             std::tie(contiguity, stride_order) =
                 computeTensorDescriptor(sizes, strides);
-            return defineTensor(
-                getTensorViewBuilderSizes(sizes, static_sizes),
-                contiguity,
-                dtype,
-                is_cpu,
-                stride_order);
+            if (!isPackedType(dtype)) {
+              return defineTensor(
+                  getTensorViewBuilderSizes(sizes, static_sizes),
+                  contiguity,
+                  dtype,
+                  is_cpu,
+                  stride_order);
+            } else {
+              auto&& [new_sizes, new_dtype] =
+                  translatePackedDtype(sizes, dtype, stride_order);
+              return defineTensor(
+                  getTensorViewBuilderSizes(new_sizes, static_sizes),
+                  contiguity,
+                  new_dtype,
+                  is_cpu,
+                  stride_order);
+            }
           },
           py::arg("sizes"),
           py::arg("strides"),
diff --git a/tests/python/direct/test_python_frontend.py b/tests/python/direct/test_python_frontend.py
index c6ddafce8bc..ef0400469e9 100644
--- a/tests/python/direct/test_python_frontend.py
+++ b/tests/python/direct/test_python_frontend.py
@@ -23,6 +23,11 @@
     verify_stride_order,
 )
 
+from python.direct_utils.narrow_precision import (
+    pytorch_nvfp4_quantize,
+    unpack_fp4_bytes,
+)
+
 
 def test_basic(nvfuser_direct_test):
     inputs = [
@@ -2615,3 +2620,36 @@ def fusion_func(fd: FusionDefinition):
         RuntimeError, match="KernelExecutor does not support the Fusion provided."
     ):
         _ = fd.execute(inputs)
+
+
+# Test that we properly handle packed type
+@pytest.mark.skipif(
+    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
+)
+def test_packed_fp4(nvfuser_direct_test):
+    t0 = torch.rand(
+        (
+            1024,
+            32,
+        ),
+        dtype=torch.float32,
+        device="cuda:0",
+    )
+    # we'll just ignore the scaling factor, since we only want to test basic fp4 support
+    t0_fp4, _ = pytorch_nvfp4_quantize(t0, 1.0)
+    inputs = [t0_fp4]
+
+    def fusion_func(fd: FusionDefinition):
+        T0 = fd.define_tensor(
+            shape=[1024, 16],
+            contiguity=[True, True],
+            dtype=DataType.Float4_e2m1fn_x2,
+            is_cpu=False,
+        )
+        T1 = fd.ops.cast(T0, DataType.Float)
+        T2 = fd.ops.relu(T1)
+        fd.add_output(T2)
+
+    out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)
+    ref = unpack_fp4_bytes(t0_fp4, torch.float32).relu()
+    nvfuser_direct_test.assertEqual(out[0], ref)