
Python bindings for block quantization #5591

Merged
protonu merged 11 commits into main from pbasu_bq_py
Dec 2, 2025

Conversation

@protonu
Collaborator

@protonu protonu commented Nov 25, 2025

This adds Python bindings for the block quantization operator.
We add Python tests that quantize to nvfp4 and validate against the rowwise 1D quantization in Transformer Engine.

  • To improve precision, the numerics in the runtime function were modified and the clamping functions removed.
  • We move the code that swizzles the block scales into a utility function.
  • We copied some code from TE to dequantize nvfp4 to fp32. This is only used in test validation.
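For intuition, dequantizing nvfp4 follows the usual e2m1 decode-then-scale pattern. The sketch below is a hedged illustration, not the helper copied from TE; the function name and layout assumptions (unpacked 4-bit codes stored one per uint8, one scale per contiguous block) are mine.

```python
import numpy as np

# FP4 (e2m1) magnitude code points; the fourth bit of a code is the sign.
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequantize_e2m1(codes, block_scales, block_size=16):
    """Decode unpacked 4-bit e2m1 codes and apply per-block scales."""
    sign = np.where(codes & 0x8, -1.0, 1.0)   # high bit of the nibble is sign
    mag = E2M1_VALUES[codes & 0x7]            # low 3 bits index the magnitude
    vals = (sign * mag).reshape(-1, block_size)
    return (vals * block_scales.reshape(-1, 1)).ravel()
```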

@protonu
Collaborator Author

protonu commented Nov 25, 2025

!test

@github-actions

github-actions bot commented Nov 25, 2025

Review updated until commit 9181f95

Description

  • Add Python bindings for block quantization operator with nv_block_quantize function

  • Implement swizzle block scales functionality as utility function for improved precision

  • Add validation to ensure swizzled scales when block scaling factor has allocation domain

  • Update CUDA kernel with improved numerics (multiplication instead of division, removed clamping)

  • Add comprehensive Python tests comparing against Transformer Engine NVFP4 quantization

Changes walkthrough

Relevant files
Enhancement
10 files
ops.cpp
Add Python binding for block quantization operation           
+42/-0   
arith.cpp
Update blockQuantize function with swizzle parameter         
+8/-1     
utils.cpp
Implement swizzleBlockScales utility function                       
+24/-0   
internal_nodes.cpp
Add swizzled_scales parameter to BlockQuantizationOp         
+3/-1     
block_quantization_kernels.cu
Update CUDA kernel with improved numerics                               
+11/-15 
python_translate.cpp
Add BlockQuantizationOp translation to Python                       
+25/-0   
narrow_precision.py
Add FP4 dequantization and swizzling utility functions     
+48/-0   
internal_nodes.h
Update BlockQuantizationOp header with swizzled_scales     
+6/-1     
utils.h
Add swizzleBlockScales function declaration                           
+5/-0     
arith.h
Update blockQuantize function signature                                   
+1/-0     
Bug fix
1 file
validation.cpp
Add validation for swizzled scales requirement                     
+5/-0     
Tests
2 files
test_narrow_precision.py
Add Python tests for block quantization vs Transformer Engine
+95/-0   
test_low_precision_recipe.cpp
Update C++ tests to use swizzle utility function                 
+79/-57 

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Restrictive validation logic

The new validation check requires swizzled scales whenever a block scaling factor has an allocation domain. This might be overly restrictive and could prevent valid use cases where users want allocation domains without swizzling. Consider if this constraint is necessary or if it should be made more flexible.

NVF_ERROR_EQ(
    bqop->isSwizzledScales(),
    true,
    "Block scaling factor with allocation domain requires swizzled "
    "scales.");
Rigid domain dimension check

The swizzleBlockScales function enforces a strict 2D loop domain requirement. This limitation might exclude valid tensor configurations that could benefit from block scaling swizzling. Consider making this validation more flexible to handle different tensor dimensionalities.

NVF_ERROR(
    tv && tv->getLoopDomain().size() == 2,
    "we can only swizzle 2D block scales tvs");
Incomplete error handling in TE comparison test

The test_nv_block_quantization_vs_te function catches exceptions during TE quantization but only prints an error message and returns None. The test continues execution and will likely fail with a cryptic error later. Consider adding explicit pytest.skip or pytest.xfail for cases where TE is not available, and ensure the test provides clear feedback about missing dependencies.

except Exception as e:
    print(f"\nError during quantization: {e}")
    import traceback

    traceback.print_exc()
    print("NOTE: This requires an NVIDIA Blackwell GPU and TE >= 1.6.")
    return None
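One way to implement the suggestion above is to skip eagerly with a clear message. This is a hedged sketch; the helper name is hypothetical and not from the PR, and pytest is imported lazily so the sketch still raises a clear error even where pytest is unavailable.

```python
def quantize_with_te(x):
    """Fail fast with an explicit skip instead of returning None and
    letting the test die later with a cryptic error."""
    try:
        import transformer_engine.pytorch  # noqa: F401
    except ImportError as err:
        import pytest  # imported lazily; only needed on the skip path
        pytest.skip(f"Transformer Engine unavailable: {err}")
    # ... actual TE quantization would go here; omitted in this sketch.
    raise NotImplementedError("sketch only")
```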

@protonu protonu requested a review from jjsjann123 November 25, 2025 21:47
@protonu protonu marked this pull request as ready for review November 25, 2025 21:47
@greptile-apps
Contributor

greptile-apps bot commented Nov 25, 2025

Greptile Overview

Greptile Summary

This PR adds Python bindings for the block quantization operator, enabling Python users to quantize tensors to NVFP4 format with block-wise scaling. The implementation aligns with Transformer Engine's rowwise 1D quantization approach.

Key Changes:

  • Added fd.ops.nv_block_quantize() Python API that returns (quantized_tensor, block_scales)
  • Updated quantization numerics to match TE behavior: replaced division with multiplication, removed manual clamping (relying on FP8 conversion for clamping), changed formula from global_scale / scaled_max to multiplication-based approach
  • Extracted block scale swizzling logic into reusable ir_utils::swizzleBlockScales() utility function
  • Added swizzle_scales parameter throughout the stack (Python bindings → C++ API → IR nodes → runtime)
  • Implemented comprehensive Python tests validating against TE's NVFP4 quantization with <10% mismatch tolerance
  • Added validation ensuring block scales with allocation domain require swizzled_scales=true

Numerics Changes:
The runtime kernel now computes scaled_max = block_max * global_scale * (1/6) then converts to FP8 and back, computing the reciprocal as global_scale / scaled_max. This differs from the old approach which divided then clamped explicitly. The FP8 conversion now handles clamping implicitly.
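In plain Python, the data flow described above looks roughly like the sketch below. The FP8 (e4m3) round-trip is approximated by an identity, so this only illustrates the arithmetic, not bit-exact kernel output; all names are illustrative.

```python
import numpy as np

E2M1_MAX = 6.0  # largest magnitude representable in FP4 (e2m1)

def quantize_block(block, global_scale):
    block_max = np.max(np.abs(block))
    scaled_max = block_max * global_scale * (1.0 / E2M1_MAX)
    # Kernel: __float2e4m3 then __e4m32float; clamping now happens
    # implicitly in that conversion. Approximated by identity here.
    decoded_scale = scaled_max
    # Reciprocal brings values into the representable FP4 range.
    recip = global_scale / decoded_scale if decoded_scale != 0.0 else 0.0
    scaled_vals = block * recip  # now in [-6, 6], ready for __float2e2m1
    return scaled_vals, decoded_scale
```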

Confidence Score: 4/5

  • This PR is safe to merge with minor risk - the numerics changes align with Transformer Engine's proven approach and are well-tested
  • Score reflects well-structured implementation with proper testing, but numerics changes to remove clamping require careful validation in production. The division-by-zero concern raised in previous threads was confirmed as matching TE behavior (developer confirmed this copies TE). Code is well-organized with good separation of concerns (utility functions, proper parameter threading). Python bindings are correctly implemented with proper docstrings and return value handling.
  • Pay close attention to runtime/block_quantization_kernels.cu - the numerics changes remove explicit clamping and change the scaling formula, which could affect edge cases with all-zero blocks or extreme values

Important Files Changed

File Analysis

Filename Score Overview
python/python_direct/ops.cpp 4/5 Added Python bindings for nv_block_quantize, exposing block quantization to Python with proper parameter handling
runtime/block_quantization_kernels.cu 4/5 Updated quantization numerics to match Transformer Engine - removed clamping, changed division to multiplication, improved precision
csrc/ops/arith.cpp 5/5 Added swizzle_scales parameter support to blockQuantize function with proper delegation to utility function
csrc/ir/utils.cpp 5/5 Extracted block scale swizzling logic into reusable swizzleBlockScales utility function
tests/python/direct/test_narrow_precision.py 4/5 Added comprehensive test comparing nvfuser block quantization against Transformer Engine reference implementation

Sequence Diagram

sequenceDiagram
    participant User as Python User
    participant Binding as ops.cpp (Python Binding)
    participant API as arith.cpp (blockQuantize)
    participant Utils as ir_utils (swizzleBlockScales)
    participant IR as BlockQuantizationOp
    participant Runtime as block_quantization_kernels.cu
    
    User->>Binding: fd.ops.nv_block_quantize(input, global_scale, swizzle_scales, block_size)
    Binding->>API: blockQuantize(input, global_scale, block_size, swizzle_scales)
    API->>API: Create quantized_tensor and block_scales TensorViews
    
    alt swizzle_scales == true
        API->>Utils: swizzleBlockScales(block_scales)
        Utils->>Utils: Apply split/merge operations for 128x4 swizzle pattern
    end
    
    API->>IR: Create BlockQuantizationOp with swizzle flag
    IR->>Runtime: Execute block_quantize_to_nvfp4 kernel
    
    Runtime->>Runtime: Compute block_max across threads
    Runtime->>Runtime: scaled_max = block_max * global_scale * (1/6)
    Runtime->>Runtime: Convert to FP8: __float2e4m3(scaled_max)
    Runtime->>Runtime: Convert back to FP32: __e4m32float(clamped_max_fp8)
    Runtime->>Runtime: Compute reciprocal: global_scale / scaled_max
    Runtime->>Runtime: Scale values: vec_in[i] * scaled_max
    Runtime->>Runtime: Quantize to FP4: __float2e2m1(scaled_vals)
    
    Runtime-->>IR: Return quantized_tensor and block_scales
    IR-->>API: Return BlockQuantizationResults
    API-->>Binding: Return (quantized_tensor, block_scales)
    Binding-->>User: Return py::tuple(quantized_tensor, block_scales)
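The 128x4 swizzle step in the diagram can be pictured with a dense reshape/permute. The sketch below shows the tiled layout commonly used for NVFP4 block scales (128 rows split as 4x32); it is my reconstruction under those assumptions, while nvFuser's swizzleBlockScales expresses the same reordering via split/merge on the 2D loop domain.

```python
import numpy as np

def swizzle_block_scales(scales):
    """Reorder (M, K) block scales into a 128x4 tiled layout.

    Assumes M % 128 == 0 and K % 4 == 0, i.e. inputs already padded.
    """
    m, k = scales.shape
    assert m % 128 == 0 and k % 4 == 0, "sketch assumes padded inputs"
    # Split 128 rows into 4 groups of 32, then interleave with column tiles.
    t = scales.reshape(m // 128, 4, 32, k // 4, 4)
    return t.transpose(0, 3, 2, 1, 4).reshape(m, k)
```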

@greptile-apps greptile-apps bot left a comment

8 files reviewed, 1 comment


Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
@greptile-apps greptile-apps bot left a comment

8 files reviewed, 1 comment



auto dtype = bqop->quantizedOutput()->as<TensorView>()->dtype();
// If the block scales tensor has an allocation, then it had
// to have been swizzled.
Collaborator

this might not be the case for sharded tensor though.

Collaborator

Having said that, I think we have some assumptions on the status of Fusion that we try to print out. Wondering if @rdspring1 has some input on this one.

Collaborator

What about adding a isSwizzled field to BlockQuantizationOp?

Collaborator Author

I was thinking about this as well.
One thing that worried me: would we need to ensure that the field and the allocation domain of the TV are always in sync? @jjsjann123, let me know what you think.

Collaborator

ensure that the field and allocation domain of the TV is always in sync?

If we have an isSwizzled field, I think we would only need it for replay. i.e. any further modification done during scheduling can be safely ignored.

I think the only messy bits are for sharded TVs, but we don't yet support sharding/scheduling in replay, so maybe leaving a comment on that would be fine for now.

@protonu protonu requested a review from rdspring1 November 26, 2025 02:09
Co-authored-by: jjsjann123 <jiej@nvidia.com>
@greptile-apps greptile-apps bot left a comment

8 files reviewed, no comments


@greptile-apps greptile-apps bot left a comment

8 files reviewed, 1 comment


@protonu protonu requested a review from jjsjann123 November 26, 2025 02:44
@protonu
Collaborator Author

protonu commented Nov 26, 2025

!test

@greptile-apps greptile-apps bot left a comment

8 files reviewed, no comments


@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


@protonu
Collaborator Author

protonu commented Dec 1, 2025

!test

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


@protonu
Collaborator Author

protonu commented Dec 1, 2025

!test

@protonu
Collaborator Author

protonu commented Dec 1, 2025

!test

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


@protonu
Collaborator Author

protonu commented Dec 2, 2025

!test

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


@jjsjann123 jjsjann123 left a comment

Happy with the updated field on block quantized op. Stamping to unblock.

bqop->isSwizzledScales(),
true,
"Block scaling factor with allocation domain requires swizzled "
"scales.");
Collaborator

note: for multi-gpu, where allocation domain would be used for sharding, we might run into issues here. But we can fix that at a later point.

@protonu protonu merged commit 9bd8c34 into main Dec 2, 2025
62 checks passed
@protonu protonu deleted the pbasu_bq_py branch December 2, 2025 17:15
protonu added a commit that referenced this pull request Dec 2, 2025
Stacked on top of #5591.

This removes old validation code in favor of the new nvfp4 dequantization
that was added in the above-mentioned PR.
No tests needed.

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: jjsjann123 <jiej@nvidia.com>
