exposing layout op at direct python binding #5345

Merged
jjsjann123 merged 15 commits into main from jj/layout_op_PR6_direct_binding on Oct 21, 2025
Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Oct 8, 2025

Stacked PRs

#5230 moe layer with nvfp4 grouped_mm
#5345 exposing layout op at direct python binding <-- this PR
#5198 refactor number of groups in layout op
#5174 allow layout op in automatic scheduler

This PR

Exposes the layout op in the Python direct binding.
Adds an nvfp4 grouped gemm Python test.

Minor fixes:

  1. ~Added support for an allocation domain on the output of the layout op in the concretization pass, to maintain the dependency of the padded allocation domain on its logical domain.~ No longer needed; handled in #5384 (Visit allIDs in IterVisitor).
  2. Skipped validation for `setAllocationDomain`.
  3. Updated the reference implementation to match the math order in nvfuser's decomposed nvfp4 quantization.

TODO:

Python tests require the IdModel indexer in order to work. See issue #5200, as well as the suggested WAR in #5200 (comment).

@github-actions

github-actions bot commented Oct 8, 2025

Review updated until commit ce77b62

Description

  • Exposed layout op preprocess_grouped_matmul_input_sf in Python direct binding

  • Added Python test for NVFP4 grouped GEMM with layout op

  • Added skip_validation flag to setAllocationDomain API

  • Updated reference implementation for NVFP4 quantization math order


Changes walkthrough 📝

Relevant files

Enhancement (5 files)

  • nodes.cpp — Add skip_validation to setAllocationDomain (+6/-3)
  • ops.cpp — Expose preprocess_grouped_matmul_input_sf in Python (+34/-0)
  • python_translate.cpp — Add Python translation for layout op (+11/-0)
  • utils.py — Add set_env context manager (+15/-0)
  • internal_base_nodes.h — Update setAllocationDomain with skip_validation (+7/-3)

Tests (3 files)

  • test_cutlass_nvfp4_gemm.py — Relax mismatch tolerance in test (+1/-1)
  • test_narrow_precision.py — Update grouped mm test inputs (+0/-4)
  • test_with_id_model_indexer.py — Add test for layout op and grouped mm (+217/-0)

Bug fix (1 file)

  • narrow_precision.py — Fix NVFP4 quantization math order (+4/-4)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Validation Bypass

The addition of a skip_validation parameter in setAllocationDomain allows bypassing domain equivalence checks, which could lead to incorrect tensor domain configurations if used improperly. This should be carefully reviewed to ensure it's only used in safe contexts.

void TensorDomain::setAllocationDomain(
    std::vector<IterDomain*> new_allocation_domain,
    std::vector<std::optional<bool>> new_contiguity,
    bool skip_validation) {
  validateContiguity(new_allocation_domain, new_contiguity);

  if (!skip_validation) {
    ir_utils::validateDomainEquivalence(
        logical_domain_, new_allocation_domain, additional_ids_);
  }

  allocation_domain_ = std::move(new_allocation_domain);
  contiguity_ = std::move(new_contiguity);
}
Environment Mutation

The test uses environment variable manipulation via set_env context manager to work around indexing issues. This approach could have side effects on other tests and should be evaluated for safer alternatives.

with set_env(NVFUSER_ENABLE="id_model(all)"):
    o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
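The `set_env` helper added in `utils.py` is not shown in the diff excerpt; a minimal sketch of such a context manager (a hypothetical reimplementation for illustration, not the PR's actual code) could look like:

```python
import os
from contextlib import contextmanager

# Hypothetical sketch of a set_env-style context manager: temporarily
# override environment variables, restoring the old state on exit.
@contextmanager
def set_env(**env_vars):
    saved = {name: os.environ.get(name) for name in env_vars}
    os.environ.update(env_vars)  # apply the temporary overrides
    try:
        yield
    finally:
        # restore previous values, deleting keys that were unset before
        for name, old in saved.items():
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old
```

Because the restore happens in a `finally` block, the variable is reset even if the test body raises; the review concern is about process-level state that a cached reader may have already consumed.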
Precision Change

The quantization implementation has been modified to clamp before converting to float8, which changes the numerical behavior compared to the original implementation that converted to float8 first. This could affect numerical accuracy and should be validated.

scaled_block_scale_fp32 = torch.clamp(
    scaled_block_scale_fp32, min=FLOAT8_E4M3_EPS, max=FLOAT8_E4M3_MAX
)
scaled_block_scale_fp8 = scaled_block_scale_fp32.to(torch.float8_e4m3fn)
total_scale = scaled_block_scale_fp32 / a_global_scale
a_scaled = a_fp32 / total_scale.unsqueeze(-1)
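To see why the order matters, here is a toy, self-contained illustration (plain Python; the float8 cast is modeled as a hypothetical saturating round to a coarse grid, not the real e4m3 format, and the values are illustrative only):

```python
# Toy model of the clamp/convert ordering. The "float8 cast" below is a
# stand-in: round-to-nearest on a grid of step 0.25 with saturation at
# +-MAX. It is NOT the real float8_e4m3fn behavior.
EPS, MAX = 0.25, 2.0

def clamp(x, lo, hi):
    return max(lo, min(hi, x))

def to_grid(x):
    # stand-in for .to(torch.float8_e4m3fn): round-to-nearest + saturate
    return clamp(round(x / 0.25) * 0.25, -MAX, MAX)

g = 1.0                         # toy global scale
x = 0.30                        # toy block scale
x_clamped = clamp(x, EPS, MAX)  # 0.30: clamp in fp32 first (new order)
x_fp8 = to_grid(x_clamped)      # 0.25: the value actually stored as "fp8"

# The new order divides by the exact clamped fp32 scale; the old order
# divided by the fp8-rounded value, giving a slightly different result.
scale_new = x_clamped / g       # 0.30
scale_old = x_fp8 / g           # 0.25
print(scale_new, scale_old)
```

The divergence is bounded by the cast's rounding error, which is why the companion change relaxes the mismatch tolerance in `test_cutlass_nvfp4_gemm.py`.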

@jjsjann123 jjsjann123 changed the title Jj/layout op pr6 direct binding exposing layout op at direct python binding Oct 8, 2025
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from e749d3f to 17dd08a Compare October 8, 2025 18:14
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from 4c2e30a to c342568 Compare October 8, 2025 18:18
@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 marked this pull request as ready for review October 8, 2025 20:54
logical_dom, layout_op->g(), layout_op->layout());
// skip validation because allocation domain doesn't converge to logical
// domain.
out_tv->domain()->setAllocationDomain(alloc_dom, true, true);
Collaborator Author

tagging @jacobhinkle since I'm touching concretization.

jjsjann123 added a commit that referenced this pull request Oct 9, 2025
## Stacked PRs

#5230 moe layer with nvfp4 grouped_mm
#5345 exposing layout op at direct python binding
#5198 refactor number of groups in layout op
#5174 allow layout op in automatic scheduler  <-- this PR

## This PR

Allow scheduler to take `PreprocessGroupedMatmulInputSf` as a pointwise
operation using the runtime function.

The main code change addresses the assumptions of the runtime
function:

- [x] add segmentation for offsets to ensure they are in global memory.
    * The existing assumption is that the two offsets inputs and the output
      of the layout op are in global memory, where the runtime function can
      read the entirety of both offsets and write the output via
      data-dependent indexing. This allows the operation to be treated as a
      trivial pointwise op.
    * avoids caching layout op outputs or offsets inputs.
    * avoids putting the layout op output into persistent buffers (since we
      require writes to global memory).

- [x] detect unsafe consumption of PreprocessGroupedMatmulInputSf output
in `fusion_segmenter.cpp`
- [x] relax asserts on the assumption that there's always a legitimate path
between loop->allocation and logical->allocation in some scheduler
utils.


TODOs for future PR:

* end-to-end Python test with direct binding.
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from 17dd08a to f669051 Compare October 10, 2025 00:44
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from bc84003 to 210bbe5 Compare October 10, 2025 00:44
std::vector<IterDomain*> new_allocation_domain,
std::vector<std::optional<bool>> new_contiguity) {
std::vector<std::optional<bool>> new_contiguity,
bool skip_validation) {
Collaborator Author

I wonder if we still need this, since we already marked allocation domain as symbolic now.

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from de4645d to b4a3a2f Compare October 14, 2025 09:04
@jjsjann123
Collaborator Author

!test


# FIXME: force indexing to use IdModel indexer to avoid indexing error.
# see issue: https://github.com/NVIDIA/Fuser/issues/5200
with set_env(NVFUSER_ENABLE="id_model(all)"):
Collaborator Author

this doesn't seem to work when we run multiple tests together, since we cache the env variable. I need to change this to the scope of this file maybe... or just skip this test.

jacobhinkle added a commit that referenced this pull request Oct 14, 2025
Previously, we only traversed all producers of the loop domain in
IterVisitor::traverseBetween. That is a problem in cases where we
schedule the producer of a reshape, or in exotic cases like #5345 where
the domains are disconnected.

This PR ensures that we traverse every ID in the TensorDomain regardless
of the relations between the domains contained within. Note that it calls
TensorDomain::allIDs when getting the "next" statements, which will do a
redundant topological sort.
jacobhinkle added a commit that referenced this pull request Oct 16, 2025
This is split off from #5345, so I don't have a specific repro for any
incorrect behavior that this fixes.

Previously, we only traversed all producers of the loop domain in
IterVisitor::traverseBetween. That is a problem in cases where we
schedule the producer of a reshape, or in exotic cases like #5345 where
the domains are disconnected.

This PR ensures that we traverse every ID in the TensorDomain regardless
of the relations between the domains contained within.

---------

Co-authored-by: jjsjann123 <jiej@nvidia.com>
jjsjann123 added a commit that referenced this pull request Oct 16, 2025
## Stacked PRs

#5230 moe layer with nvfp4 grouped_mm
#5345 exposing layout op at direct python binding
#5198 refactor number of groups in layout op  <-- this PR
#5174 allow layout op in automatic scheduler

## This PR

This is a tiny refactor to only expect the two `offsets` to have
size equal to num_groups. The reason is that our cutlass kernel
expected that in the first place and I didn't match it correctly the
first time.

e.g. with total sequence length 10 and tokens per expert [2, 3, 5]:
previously the offsets would be [0, 2, 5, 10]; after the refactor, the
offsets are [0, 2, 5].
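The offsets in that example are just an exclusive prefix sum over the per-expert token counts; a small sketch (illustrative only, not the actual kernel code):

```python
from itertools import accumulate

def group_offsets(tokens_per_expert):
    # exclusive prefix sum: offsets[i] is where group i's tokens start
    return list(accumulate([0] + tokens_per_expert[:-1]))

tokens = [2, 3, 5]                              # tokens per expert
old_offsets = list(accumulate([0] + tokens))    # [0, 2, 5, 10]: includes total
new_offsets = group_offsets(tokens)             # [0, 2, 5]: len == num_groups
print(old_offsets, new_offsets)
```

The refactor drops the trailing total (10 here), so each offsets tensor has exactly num_groups entries, matching what the cutlass kernel expects.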
Base automatically changed from jj/layout_op_PR5_python_direct_binding to main October 16, 2025 20:11
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from b4a3a2f to fe8da62 Compare October 16, 2025 20:23
@jjsjann123
Collaborator Author

Had to rebase & force push to avoid resolving all conflicts from the base change.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

everything looks good, except the env var thing that's not working properly. I might change that to a standalone python test, just so it's not going to mess with the other tests.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 merged commit 7ee0a94 into main Oct 21, 2025
67 checks passed
@jjsjann123 jjsjann123 deleted the jj/layout_op_PR6_direct_binding branch October 21, 2025 04:34
tbqh pushed commits that referenced this pull request Nov 12, 2025