fp4 packed dtype support on direct python API#5380

Merged
jjsjann123 merged 12 commits into main from jj/nvfp4_direct_binding on Oct 17, 2025
Conversation

@jjsjann123
Collaborator

Cherry-picked from #5230

  • The packed fp4 dtype needs to be supported by the Python API in order to support framework integration.

FusionDefinition does not expect a packed dtype, but since that is the only fp4 dtype the framework supports, our integration still needs to handle it.

This PR adds a quick translation in FusionDefinition.define_tensor that converts the packed dtype into its unpacked equivalent, keeping the WAR transparent to integration/users.
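As a rough Python sketch of the translation just described (the function name and signature here are illustrative, not the actual nvFuser API): each packed element holds two fp4 values, so the fastest-varying dimension of the logical shape doubles when the dtype is unpacked.

```python
def translate_packed_shape(shape, stride_order=None):
    """Return the unpacked shape for a packed fp4 (x2) tensor.

    Hypothetical helper mirroring the translation this PR performs:
    the fastest-varying dimension doubles because each packed element
    stores two fp4 values.
    """
    fastest_dim = len(shape) - 1
    if stride_order is not None:
        # A stride_order entry of 0 marks the fastest-varying dimension.
        for i, val in enumerate(stride_order):
            if val == 0:
                fastest_dim = i
                break
    unpacked = list(shape)
    unpacked[fastest_dim] *= 2
    return unpacked

print(translate_packed_shape([1024, 16]))          # [1024, 32]
print(translate_packed_shape([16, 1024], [0, 1]))  # [32, 1024]
```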

@github-actions

github-actions bot commented Oct 13, 2025

Review updated until commit 5a84246

Description

  • Add support for packed fp4 dtype in Python API

  • Translate packed dtype to unpacked dtype transparently

  • Handle shape adjustment for fastest dimension in packed format

  • Include tests for fp4 packed dtype on supported hardware


Changes walkthrough 📝

Relevant files

Enhancement — python/python_direct/ir.cpp (+65/−13): Add packed fp4 dtype translation in define_tensor

  • Added translatePackedDtype to convert packed fp4 shape and dtype
  • Doubles the fastest dimension size during unpacking
  • Integrated dtype translation into define_tensor bindings
  • Handles both explicit and implicit contiguity cases

Tests — tests/python/direct/test_python_frontend.py (+38/−0): Add test for packed fp4 dtype support

  • Imported fp4 quantization and unpacking utilities
  • Added test for packed fp4 tensor handling
  • Skips test on pre-Blackwell GPUs
  • Compares output against reference unpacked result

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The function translatePackedDtype assumes the input dtype is exactly Float4_e2m1fn_x2 via NVF_CHECK, but does not handle other potential packed dtypes. This may lead to failures if new packed types are introduced without updating this check.

    // return the unpacked shape and dtype for a given packed dtype, where we need
    // to double the size of the inner most dimension.
    std::tuple<std::vector<int64_t>, PrimDataType> translatePackedDtype(
        const std::vector<int64_t>& shape,
        const PrimDataType dtype,
        const std::vector<int64_t>& stride_order) {
      // TODO: switch to isPackedType when the pack width is retrieved through
      // utility functions as well.
      NVF_CHECK(dtype == DataType::Float4_e2m1fn_x2);
    
      int fastest_dim = shape.size() - 1;
      for (const auto& [i, val] : enumerate(stride_order)) {
        if (val == 0) {
          fastest_dim = i;
          break;
        }
      }
      std::vector<int64_t> un_packed_shape = shape;
      un_packed_shape[fastest_dim] *= 2;
      return {un_packed_shape, DataType::Float4_e2m1fn};
    }
    Performance Consideration

    The translation from packed to unpacked dtype doubles the innermost logical dimension, which could affect tensor layout and memory access patterns. The performance implications of this expansion, especially on memory bandwidth, should be evaluated and documented.

    if (!isPackedType(dtype)) {
      return defineTensor(
          shape, contiguity, dtype, is_cpu, stride_order);
    } else {
      auto&& [new_shape, new_dtype] =
          translatePackedDtype(shape, dtype, stride_order);
      return defineTensor(
          new_shape, contiguity, new_dtype, is_cpu, stride_order);
    }
    Test Coverage

    The test test_packed_fp4 uses a fixed shape and stride order; additional test cases with non-contiguous tensors or different stride orders should be considered to ensure robustness of the packed dtype handling.

    # Test that we properly handle packed type
    @pytest.mark.skipif(
        is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
    )
    def test_packed_fp4(nvfuser_direct_test):
        t0 = torch.rand(
            (
                1024,
                32,
            ),
            dtype=torch.float32,
            device="cuda:0",
        )
        # we'll just ignore the scaling factor, since we only want to test basic fp4 support
        t0_fp4, _ = pytorch_nvfp4_quantize(t0, 1.0)
        inputs = [t0_fp4]
    
        def fusion_func(fd: FusionDefinition):
            T0 = fd.define_tensor(
                shape=[1024, 16],
                contiguity=[True, True],
                dtype=DataType.Float4_e2m1fn_x2,
                is_cpu=False,
            )
            T1 = fd.ops.cast(T0, DataType.Float)
            T2 = fd.ops.relu(T1)
            fd.add_output(T2)
    
        out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)
        ref = unpack_fp4_bytes(t0_fp4, torch.float32).relu()
        nvfuser_direct_test.assertEqual(out[0], ref)
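    For readers unfamiliar with the packed layout the test exercises: each byte of a Float4_e2m1fn_x2 tensor holds two e2m1 fp4 values (1 sign bit, 2 exponent bits, 1 mantissa bit each). The sketch below is illustrative only and is not the unpack_fp4_bytes helper used in the test; the assumption that the low nibble holds the first value is ours.

```python
# All eight non-negative e2m1 magnitudes, indexed by (exp << 1) | mantissa.
_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(nibble: int) -> float:
    """Decode a single 4-bit e2m1 value (sign, 2-bit exp, 1-bit mantissa)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * _E2M1_VALUES[nibble & 0x7]

def unpack_fp4_pair(byte: int) -> tuple:
    """Split one packed byte into its two fp4 values.

    Low-nibble-first ordering is an assumption for illustration.
    """
    return decode_e2m1(byte & 0xF), decode_e2m1(byte >> 4)

print(unpack_fp4_pair(0x2C))  # (-2.0, 1.0) under the assumed layout
```

    This is why the fastest dimension doubles on unpack: a [1024, 16] packed tensor decodes to [1024, 32] fp4 values.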

    @jjsjann123
    Collaborator Author

    Cherry-picked from already-stamped changes by @rdspring1 in #5230. I figure it's easier to merge it like this.

    @jjsjann123
    Collaborator Author

    !test


    @jjsjann123 jjsjann123 requested a review from rdspring1 October 13, 2025 21:56


    @jjsjann123 jjsjann123 merged commit 95c857f into main Oct 17, 2025
    66 of 67 checks passed
    @jjsjann123 jjsjann123 deleted the jj/nvfp4_direct_binding branch October 17, 2025 15:20
    tbqh pushed a commit that referenced this pull request Nov 12, 2025
