Support layout op in scheduler #5174

Merged
jjsjann123 merged 15 commits into main from jj/layout_op_PR4_scheduler_support
Oct 9, 2025

Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Sep 17, 2025

Stacked PRs

#5230 moe layer with nvfp4 grouped_mm
#5345 exposing layout op at direct python binding
#5198 refactor number of groups in layout op
#5174 allow layout op in automatic scheduler <-- this PR

This PR

Allow the scheduler to take PreprocessGroupedMatmulInputSf as a pointwise operation using the runtime function.

The main code change addresses the assumptions of the runtime function:

  • add segmentation for offsets to ensure they are in global memory.

    • The existing assumption is that the two offsets inputs and the output of the layout op are in global memory, where the runtime function can read the entirety of both offsets and write the output via data-dependent indexing. This allows the operation to be treated as a trivial pointwise op.
    • avoid caching layout op outputs or offsets inputs.
    • avoid putting the layout op output into persistent buffers (since we require a write to global memory).
  • detect unsafe consumption of PreprocessGroupedMatmulInputSf output in fusion_segmenter.cpp

  • relax asserts on the assumption that there is always a valid path between loop->allocation and logical->allocation in some scheduler utils.
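The data-dependent indexing mentioned above can be sketched with a hypothetical host-side model (the function name, signature, and group-scan strategy are assumptions for illustration, not nvFuser's actual runtime function). It shows why both offsets arrays must be readable in their entirety: an element's destination depends on which group its row falls into, which is only discoverable by scanning the offsets.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of the layout op's data-dependent indexing: locate the
// group containing `row` by scanning the input offsets, then map it to its
// padded position in the output via the output offsets.
int64_t paddedOutputIndex(
    const std::vector<int64_t>& input_offsets,   // per-group input row starts
    const std::vector<int64_t>& output_offsets,  // per-group padded row starts
    int64_t row,
    int64_t col,
    int64_t row_stride) {
  // Scan the entire input_offsets array to find the group containing `row`.
  std::size_t g = 0;
  while (g + 1 < input_offsets.size() && input_offsets[g + 1] <= row) {
    ++g;
  }
  int64_t local_row = row - input_offsets[g];
  return (output_offsets[g] + local_row) * row_stride + col;
}
```

Because every thread may need any entry of either offsets array, caching them per-thread or keeping them in registers does not work, which is what motivates the global-memory requirement above.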

TODOs for future PR:

  • end-to-end Python test with the direct binding.

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 changed the title Jj/layout op pr4 scheduler support Support layout op in scheduler Sep 17, 2025
@github-actions

github-actions bot commented Sep 17, 2025

Review updated until commit fab8823

Description

  • Support PreprocessGroupedMatmulInputSf in automatic scheduler

  • Enforce global memory for layout op inputs/outputs

  • Validate safe consumption of layout op outputs

  • Relax domain mapping assumptions in scheduler utilities


Changes walkthrough 📝

Relevant files
Enhancement
7 files
fusion_segmenter.cpp
Add layout op safety checks and domain handling                   
+57/-10 
indexing.cpp
Avoid resize op in layout padding logic                                   
+10/-0   
registry.cpp
Remove layout op from unsupported ops list                             
+8/-2     
registry_utils.cpp
Add global buffer checks for layout op                                     
+46/-4   
domain_map.cpp
Skip offset TVs in domain mapping                                               
+8/-0     
utils.cpp
Exclude layout op from persistent buffer candidates           
+39/-17 
registry_utils.h
Declare global buffer requirement checker                               
+11/-0   
Bug fix
1 file
allocation_utils.cpp
Allow null layout when logical-allocation mismatch             
+7/-6     
Tests
1 file
test_layout_op.cpp
Add scheduler integration tests for layout op                       
+192/-0 

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Safety Check

The safety check for uses of PreprocessGroupedMatmulInputSf output only allows consumption by CutlassNvfp4GroupedMmaOp and specifically excludes it being used as scale1 or scale2. However, the condition in the loop seems to be checking the opposite: it fails if the input is NOT scale1 and NOT scale2 and is the tv_ptr, which might be incorrect. This logic should be reviewed to ensure it properly enforces safe usage.

      for (Expr* use : tv_ptr->uses()) {
        // clangtidy's false negative static analysis complains about use, see:
        // https://github.com/llvm/llvm-project/issues/134454#issuecomment-2816262570
        // However, the assert trick didn't seem to work here.
#if defined(__clang__)
        [[clang::suppress]] {
#endif
          auto* layout_op = dynamic_cast<CutlassNvfp4GroupedMmaOp*>(use);
          NVF_ERROR(
              layout_op,
              "use of output from PreprocessGroupedMatmulInputSf is unsafe by "
              "operation:",
              use->toString());
          NVF_ERROR(
              std::none_of(
                  layout_op->inputs().begin(),
                  layout_op->inputs().end(),
                  [&](const Val* input) {
                    // we can only use output from
                    // PreprocessGroupedMatmulInputSf as block scaling factor
                    return layout_op->scale1() != input &&
                        layout_op->scale2() != input && input == tv_ptr;
                  }),
              "use of output from PreprocessGroupedMatmulInputSf is unsafe by "
              "operation:",
              use->toString(),
              " as argument: ",
              tv_ptr->toString());
#if defined(__clang__)
        }
#endif
      }
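To make the predicate's polarity concrete, here is a minimal standalone model of the lambda above (the function name and raw-pointer signature are hypothetical, for illustration only): an input is flagged as unsafe exactly when it is the layout-op output but is bound to neither scale argument, so `std::none_of` passes only when every use of the output is as a block scaling factor.

```cpp
// Hypothetical model of the safety-check lambda (not nvFuser source).
// Returns true for an unsafe binding: `input` is the layout-op output
// (`tv_ptr`) yet is used as neither scale1 nor scale2.
bool isUnsafeUse(
    const void* input,
    const void* scale1,
    const void* scale2,
    const void* tv_ptr) {
  return scale1 != input && scale2 != input && input == tv_ptr;
}
```

Under this model, a use of the output as scale1 or scale2 is not flagged, while any other binding of the output is, which matches the stated rule that the output may only be consumed as a block scaling factor.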
Buffer Requirement

The function rejectScheduleFusionGlobalBufferRequirement checks that the output and both offset inputs of PreprocessGroupedMatmulInputSf are fusion outputs without consumers and fusion inputs respectively. However, for the output, being a fusion output without consumers may be too restrictive since the output is expected to be consumed by CutlassNvfp4GroupedMmaOp within the same fusion. This requirement might incorrectly reject valid fusions.

bool SchedulerTopologyChecker::rejectScheduleFusionGlobalBufferRequirement(
    Fusion* fusion,
    SchedulerType scheduler_type) {
  for (auto expr : fusion->exprs()) {
    if (expr->isA<PreprocessGroupedMatmulInputSf>()) {
      // The runtime function of layout_op needs:
      //   1. Write output directly to global memory
      //   2. Read two offset inputs directly from global memory
      auto layout_op = expr->as<PreprocessGroupedMatmulInputSf>();
      if (rejectScheduleFusionOutputRequirement(
              layout_op, layout_op->out(), scheduler_type) ||
          rejectScheduleFusionInputRequirement(
              layout_op, layout_op->inputOffsets(), scheduler_type) ||
          rejectScheduleFusionInputRequirement(
              layout_op, layout_op->outputOffsets(), scheduler_type)) {
        return true;
      }
    }
  }
  return false;
}
Persistent Buffer

The comment suggests that skipping PreprocessGroupedMatmulInputSf in persistent buffer analysis is a workaround (WAR) due to an assert in isCacheableUnmappableTv. This indicates a potential underlying issue in the persistent buffer analysis logic that should be addressed rather than worked around, as it might mask other problems.

// Adding PreprocessGroupedMatmulInputSf op to the list to skip it from
// being considered as candidate for persistent buffer. Otherwise, the
// lack of mapping between all producers to consumer triggers an assert in
// the check later inside `isCacheableUnmappableTv`. This feels like a
// reasonable WAR, since producer of indexing ops have been excluded from
// persistent_buffer candidates.
if (consumer->definition()
        ->isOneOf<
            SelectOp,
            IndexSelectOp,
            GatherOp,
            PreprocessGroupedMatmulInputSf>()) {
  continue;
}

// check PreprocessGroupedMatmulInputSf's output is in global memory by
// forcing it as a fusion segment output
Collaborator Author

@naoyam I think I should have done a memory promotion instead of forcing a segment. But since the consumer is in a separate kernel anyway, I don't think the two alternatives have any meaningful difference.

  scheduler_debug_utils::canScheduleRejectReason(
      scheduler_type, "Fusion has consumer of non indexable ops.");
  return false;
}
Collaborator Author

Note for myself. I think I need to do the same for offsets. promoting memory type to shared/global so each thread can access the entire offsets in case they are computed within the kernel.

I'm not sure if sync analysis during lowering would have done that for me automatically.

@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 818c574 to 567d2ac on September 18, 2025 21:44
@jjsjann123 jjsjann123 changed the base branch from jj/layout_op_PR3_allocation_patch to jj/allocation_for_layout_op_PR_1 on September 18, 2025 21:45
@jjsjann123 jjsjann123 force-pushed the jj/allocation_for_layout_op_PR_1 branch from 33d0ce3 to 17df15a on September 19, 2025 22:55
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 567d2ac to 34b394d on September 19, 2025 22:58
@jjsjann123 jjsjann123 force-pushed the jj/allocation_for_layout_op_PR_1 branch from 17df15a to f9acfc3 on September 22, 2025 21:50
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 34b394d to 3f56fcc on September 22, 2025 21:54
@jjsjann123 jjsjann123 force-pushed the jj/allocation_for_layout_op_PR_1 branch from f9acfc3 to 87afb60 on September 23, 2025 17:39
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from c6ea939 to 64a9f9e on September 23, 2025 17:41
@jjsjann123 jjsjann123 force-pushed the jj/allocation_for_layout_op_PR_1 branch 2 times, most recently from c64d299 to ea3cd68 on October 1, 2025 20:17
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 64a9f9e to 9c2d85e on October 1, 2025 21:41
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 9c2d85e to 30d1699 on October 7, 2025 22:36
@jjsjann123 jjsjann123 requested a review from naoyam October 7, 2025 23:23
@jjsjann123 jjsjann123 marked this pull request as ready for review October 7, 2025 23:23
jjsjann123 added a commit that referenced this pull request Oct 7, 2025
## Stacked PRs

A follow-up PR enabling the Python API and updating test_moe.py is still
being cleaned up.
#5174 allow layout op in automatic scheduler
#5185 Fix allocation logic: unconnected alloc/logical    <- this one

## This PR

Fixes the allocation logic to ensure that the output tensor has:
1. a shape matching its logical domain;
2. a buffer size matching the allocation domain.

Without this PR, the output tensor from `PreprocessGroupedMatmulInputSf`
would have a shape mismatching its logical domain, causing validation
failure in downstream consumers.

### Context

The PreprocessGroupedMatmulInputSf op has:
1. unconnected logical and allocation domains;
2. a larger allocation size, because extra padding is represented via
arithmetic operations directly on the extent.

The existing allocation logic allocates a buffer matching the logical
sizes/strides. This is not the right behavior, because the allocation
domain could have a larger extent. We also cannot use the allocation
sizes/strides, because consumers of the tensor expect a tensor matching
the logical size.

We updated the logic to use the allocation domain for buffer allocation,
then slice into the buffer using the logical domain to produce the
correct-sized output.
In the case of PreprocessGroupedMatmulInputSf, because there is no
correct way to slice into the buffer for indexing, we give up on
producing correct strides and just use naive strides instead. This is
safe, since we are not using indexing logic on the output.

### Code change

1. refactor buffer allocation to use the allocation domain instead of
the logical domain.
2. fix the projection from allocation to logical in the special path
where projection is not possible: we now compute the correct extent
instead of returning the allocation buffer as-is. This allows the layout
op to return a tensor with the correct logical size while still
allocating a large enough buffer to accommodate the padding requirement.
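The gap between the logical and allocation extents described here can be illustrated with a small hypothetical helper (the alignment value and names are assumptions for illustration, not nvFuser's actual constants): the logical extent is the plain sum of rows, while the allocation extent rounds each group up to an alignment boundary.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: logical extent = sum of per-group rows.
int64_t logicalRows(const std::vector<int64_t>& tokens_per_group) {
  int64_t total = 0;
  for (int64_t t : tokens_per_group) {
    total += t;
  }
  return total;
}

// Hypothetical sketch: allocation extent rounds each group up to `align`,
// mirroring padding expressed as arithmetic on the extent. The buffer must be
// sized by this larger value, then viewed at the logical size.
int64_t paddedAllocationRows(
    const std::vector<int64_t>& tokens_per_group,
    int64_t align) {
  int64_t total = 0;
  for (int64_t t : tokens_per_group) {
    total += (t + align - 1) / align * align;  // round each group up
  }
  return total;
}
```

For example, with tokens per group [2, 3, 5] and an assumed alignment of 4, the logical extent is 10 rows while the allocation extent is 16 rows, which is why allocating by logical sizes under-allocates.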
Base automatically changed from jj/allocation_for_layout_op_PR_1 to main October 7, 2025 23:41
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 012365d to 6d0b512 on October 7, 2025 23:44
test case added

wip

prevent cacheAndForkOutputs

disable cacheInputs for offsets TVs

change domain stuff in reference TV

revert unused changes

err something isn't working right

wip

adding more test examples

misc fixes

adding quick note for me to pick up later

break consumer of layout op into separate fusion

err

quick fix on logic

refactor to use resize for transform replay

clear allocation domain

revert resize change for allocation domain; wipe out allocation for layout op on fusion segment boundaries

trying to patch allocation domain handling during allocation

fix cases where allocation domain isn't available

fixing allocation transform and fixing tests

fixing test sizes

I think I'm doing the right thing

fixing normalization scheduler example

clangformat

forgot to import namespaces

adding comment; warning; clangformat

fixing the safety check to after adding outputs

lintrunner

fixing errors from main changes; updating segmenter logic to force segment on layout op offset inputs to ensure they are in global memory

switching warning to asserts
@jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 6d0b512 to 350f318 on October 7, 2025 23:49
@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

for (auto inp : getAllInputs(sg)) {
  auto clone_tv = complete_to_segment_map.clone(inp);
  fusion_segment->addInput(clone_tv);
  if (inp->isDefinitionType<ReshapeOp>()) {
Collaborator

Do we still need this branch?

Collaborator Author

I don't see why we need it, and it doesn't have an error message. So my gut feeling is to just remove it, since the assert is not adding more debugging insight.

Collaborator

I'm guessing there was something here for reshape. It's deleted but not completely.

  NVF_ERROR(clone_tv != nullptr && clone_tv->isA<TensorView>());
  view_tvs.push_back(clone_tv->as<TensorView>());
} else if (inp->isDefinitionType<PreprocessGroupedMatmulInputSf>()) {
  // There's no point of replaying allocation domain if we cannot index into
Collaborator

Oh, yes, this is an input.

Is there any reason to have the assertion at line 1896 there, not here?

Also, where do you have a problem with TransformReplay? If the segment just contains the CUTLASS grouped MMAOp, where does TransformReplay bite us?


static bool hasResizeAndIndexOps(Fusion* fusion);

// Checks if fusion contains illegal non-indexable ops. E.g. for
Collaborator

I see. The name sounds a little weird to me. We already have rejectScheduleFusionInputRequirement. Why not just add something similar for fusion outputs and use them?

@jjsjann123
Collaborator Author

!test

bool rejectScheduleFusionOutputRequirement(
    Expr* expr,
    SchedulerType scheduler_type) {
  TensorView* out = ir_utils::getTvOutput(expr);
Collaborator

Don't all outputs need to be checked?

Collaborator Author

Since we are only using it for the layout op, which has only a single output TV, I didn't bother doing that.

Conceptually, similar to inputs, we might only want to check that certain outputs stay in global memory. I would argue that we should have
bool rejectScheduleFusionOutputRequirement(Expr* expr, Val* val, SchedulerType scheduler_type);

Collaborator

Yeah, I think the consistent interface would look better, so +1 for rejectScheduleFusionOutputRequirement(Expr* expr, Val* val, SchedulerType scheduler_type);

@naoyam
Collaborator

naoyam commented Oct 9, 2025

Is PreprocessGroupedMatmulInputSf ever fused with some other ops? Or is it always by itself?

@jjsjann123
Copy link
Collaborator Author

Is PreprocessGroupedMatmulInputSf ever fused with some other ops? Or is it always by itself?

We fuse it to its producers (i.e. the quantize math that produces the blocked scaling factor).

@jjsjann123
Collaborator Author

!test

Collaborator

@naoyam naoyam left a comment

LGTM

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

Failures are not related. Merging as-is.

@jjsjann123 jjsjann123 merged commit 117a4b8 into main Oct 9, 2025
63 of 65 checks passed
@jjsjann123 jjsjann123 deleted the jj/layout_op_PR4_scheduler_support branch October 9, 2025 21:28
jjsjann123 added a commit that referenced this pull request Oct 16, 2025
## Stacked PRs

#5230 moe layer with nvfp4 grouped_mm
#5345 exposing layout op at direct python binding
#5198 refactor number of groups in layout op  <-- this PR
#5174 allow layout op in automatic scheduler

## This PR

This is a tiny refactor to only expect the two `offsets` to have a size
equal to num_groups. The reason is that our cutlass kernel was expecting
that in the first place and I didn't match it right the first time.

e.g. with a total sequence length of 10 and tokens per expert [2, 3, 5]:
previously the offsets would be [0, 2, 5, 10]; after the refactor, the
offsets are [0, 2, 5].
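The refactored convention is simply an exclusive prefix sum truncated to num_groups entries, with the trailing total dropped. A minimal sketch (the helper name is hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the offsets convention after the refactor: an
// exclusive prefix sum over tokens-per-expert, without the trailing total
// sequence length.
std::vector<int64_t> groupOffsets(const std::vector<int64_t>& tokens_per_expert) {
  std::vector<int64_t> offsets(tokens_per_expert.size());
  int64_t running = 0;
  for (std::size_t i = 0; i < tokens_per_expert.size(); ++i) {
    offsets[i] = running;
    running += tokens_per_expert[i];
  }
  return offsets;
}
```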
jjsjann123 added a commit that referenced this pull request Oct 21, 2025
## Stacked PRs

#5230 moe layer with nvfp4 grouped_mm
#5345 exposing layout op at direct python binding  <-- this PR
#5198 refactor number of groups in layout op
#5174 allow layout op in automatic scheduler

## This PR

Expose the layout op in the Python direct bindings.
Added an nvfp4 grouped gemm Python test.

Minor fixes:

1. ~Added support of allocation domain for output of layout op in
concretization pass to maintain the dependency on padded allocation
domain to its logical domain.~ No longer needed, handled in
#5384
2. Skipped validation for `setAllocationDomain`.
3. Updated the reference implementation to match the math order in
nvFuser's decomposed nvfp4 quantization.

TODO:

The Python tests require the IdModel indexer in order to work. See issue
#5200, as well as the suggested WAR in
#5200 (comment)
tbqh pushed a commit that referenced this pull request Nov 12, 2025
tbqh pushed a commit that referenced this pull request Nov 12, 2025
tbqh pushed a commit that referenced this pull request Nov 12, 2025
tbqh pushed a commit that referenced this pull request Nov 12, 2025