Skip to content
78 changes: 65 additions & 13 deletions python/python_direct/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,28 @@ TensorView* defineTensor(
return tv;
}

// Return the unpacked shape and dtype for a given packed dtype, where the
// size of the innermost (fastest-varying) dimension is doubled.
//
// Example: a [M, N] tensor of Float4_e2m1fn_x2 (two fp4 values packed per
// byte) is exposed to the fusion as a [M, 2*N] tensor of Float4_e2m1fn.
std::tuple<std::vector<int64_t>, PrimDataType> translatePackedDtype(
    const std::vector<int64_t>& shape,
    const PrimDataType dtype,
    const std::vector<int64_t>& stride_order) {
  // TODO: switch to isPackedType when the pack width is retrieved through
  // utility functions as well.
  NVF_CHECK(dtype == DataType::Float4_e2m1fn_x2);
  // A packed dtype implies at least one dimension to unpack. Without this
  // check an empty shape would wrap shape.size() - 1 around to a garbage
  // index below.
  NVF_CHECK(!shape.empty(), "Packed dtype requires a non-empty shape");

  // The fastest-varying dimension is the one whose stride_order entry is 0;
  // when stride_order is empty, default to the last dimension.
  // Use int64_t (not int) to avoid narrowing from size_t.
  int64_t fastest_dim = static_cast<int64_t>(shape.size()) - 1;
  for (const auto& [i, val] : enumerate(stride_order)) {
    if (val == 0) {
      fastest_dim = static_cast<int64_t>(i);
      break;
    }
  }
  // NOTE(review): if the fastest dimension is symbolic (size -1), doubling
  // yields -2 — confirm callers never pass a dynamic innermost extent here.
  std::vector<int64_t> unpacked_shape = shape;
  unpacked_shape[fastest_dim] *= 2;
  return {unpacked_shape, DataType::Float4_e2m1fn};
}

void bindDefineTensor(py::module& nvfuser) {
nvfuser
.def(
Expand All @@ -560,7 +582,15 @@ void bindDefineTensor(py::module& nvfuser) {
const bool is_cpu = false,
const std::vector<int64_t>& stride_order = {}) -> TensorView* {
verifyShape(shape);
return defineTensor(shape, contiguity, dtype, is_cpu, stride_order);
if (!isPackedType(dtype)) {
return defineTensor(
shape, contiguity, dtype, is_cpu, stride_order);
} else {
auto&& [new_shape, new_dtype] =
translatePackedDtype(shape, dtype, stride_order);
return defineTensor(
new_shape, contiguity, new_dtype, is_cpu, stride_order);
}
},
py::arg("shape"),
py::arg("contiguity"),
Expand All @@ -577,12 +607,23 @@ void bindDefineTensor(py::module& nvfuser) {
const bool is_cpu = false,
const std::vector<int64_t>& stride_order = {}) -> TensorView* {
verifyShape(shape);
return defineTensor(
shape,
getContiguityVec(shape, stride_order, contiguity),
dtype,
is_cpu,
stride_order);
if (!isPackedType(dtype)) {
return defineTensor(
shape,
getContiguityVec(shape, stride_order, contiguity),
dtype,
is_cpu,
stride_order);
} else {
auto&& [new_shape, new_dtype] =
translatePackedDtype(shape, dtype, stride_order);
return defineTensor(
new_shape,
getContiguityVec(new_shape, stride_order, contiguity),
new_dtype,
is_cpu,
stride_order);
}
},
py::arg("shape"),
py::arg("contiguity") = false,
Expand All @@ -606,12 +647,23 @@ void bindDefineTensor(py::module& nvfuser) {
std::vector<int64_t> stride_order;
std::tie(contiguity, stride_order) =
computeTensorDescriptor(sizes, strides);
return defineTensor(
getTensorViewBuilderSizes(sizes, static_sizes),
contiguity,
dtype,
is_cpu,
stride_order);
if (!isPackedType(dtype)) {
return defineTensor(
getTensorViewBuilderSizes(sizes, static_sizes),
contiguity,
dtype,
is_cpu,
stride_order);
} else {
auto&& [new_sizes, new_dtype] =
translatePackedDtype(sizes, dtype, stride_order);
return defineTensor(
getTensorViewBuilderSizes(new_sizes, static_sizes),
contiguity,
new_dtype,
is_cpu,
stride_order);
}
},
py::arg("sizes"),
py::arg("strides"),
Expand Down
38 changes: 38 additions & 0 deletions tests/python/direct/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
verify_stride_order,
)

from python.direct_utils.narrow_precision import (
pytorch_nvfp4_quantize,
unpack_fp4_bytes,
)


def test_basic(nvfuser_direct_test):
inputs = [
Expand Down Expand Up @@ -2615,3 +2620,36 @@ def fusion_func(fd: FusionDefinition):
RuntimeError, match="KernelExecutor does not support the Fusion provided."
):
_ = fd.execute(inputs)


# Test that we properly handle packed dtypes (e.g. Float4_e2m1fn_x2)
@pytest.mark.skipif(
    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
)
def test_packed_fp4(nvfuser_direct_test):
    """A packed fp4 input (Float4_e2m1fn_x2) flows through a simple
    cast + relu fusion and matches an eager-mode unpacked reference."""
    full_precision = torch.rand(
        (1024, 32),
        dtype=torch.float32,
        device="cuda:0",
    )
    # The scaling factor is irrelevant here; we only exercise basic fp4
    # support, so fix it at 1.0 and discard the returned scale.
    packed, _ = pytorch_nvfp4_quantize(full_precision, 1.0)
    inputs = [packed]

    def fusion_func(fd: FusionDefinition):
        # Two fp4 values are packed per byte, so the declared innermost
        # extent is half the logical width: 32 floats -> 16 packed bytes.
        t_in = fd.define_tensor(
            shape=[1024, 16],
            contiguity=[True, True],
            dtype=DataType.Float4_e2m1fn_x2,
            is_cpu=False,
        )
        t_float = fd.ops.cast(t_in, DataType.Float)
        t_relu = fd.ops.relu(t_float)
        fd.add_output(t_relu)

    out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)
    ref = unpack_fp4_bytes(packed, torch.float32).relu()
    nvfuser_direct_test.assertEqual(out[0], ref)
Loading