
refactor number of groups in layout op#5198

Merged
jjsjann123 merged 6 commits intomainfrom
jj/layout_op_PR5_python_direct_binding
Oct 16, 2025

Conversation


@jjsjann123 jjsjann123 commented Sep 19, 2025

Stacked PRs

#5230 moe layer with nvfp4 grouped_mm
#5345 exposing layout op at direct python binding
#5198 refactor number of groups in layout op <-- this PR
#5174 allow layout op in automatic scheduler

This PR

This is a tiny refactor to only expect the two offsets to have size equal to num_groups. The reason is that our cutlass kernel was expecting that in the first place, and I didn't match it correctly the first time.

e.g. with total sequence length 10, and tokens per expert [2, 3, 5]
Previously the offsets would be [0, 2, 5, 10]; after the refactor, the offsets would be [0, 2, 5].
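The new convention is just an exclusive prefix sum over tokens-per-expert, with the total length dropped from the end. A minimal Python sketch (function name hypothetical) of that convention:

```python
def group_offsets(tokens_per_expert):
    # Exclusive prefix sum: offsets[i] is the starting row of group i.
    # After the refactor, offsets has exactly num_groups entries; the
    # total sequence length is no longer appended as a final element.
    offsets, running = [], 0
    for n in tokens_per_expert:
        offsets.append(running)
        running += n
    return offsets

print(group_offsets([2, 3, 5]))  # [0, 2, 5] (previously [0, 2, 5, 10])
```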


github-actions bot commented Sep 19, 2025

Review updated until commit a85c363

Description

  • Fix group indexing in layout op kernel and CPU logic

  • Refactor allocation domain logic for grouped matmul input

  • Adjust offset handling to match num_groups correctly

  • Update tests to reflect new offset semantics


Changes walkthrough 📝

Relevant files

Bug fix
  indexing.cpp — Refactor and fix allocation domain for grouped matmul
  csrc/ops/indexing.cpp
    • Extracted allocation domain logic into layoutAllocationDomain
    • Fixed num_groups calculation using offset extent directly
    • Updated padding logic to use new group count
    • Refactored preprocessGroupedMatmulInputSf to use helper function
    • +45/-30

  block_layout.cu — Fix expert ID indexing in CUDA kernel
  runtime/block_layout.cu
    • Fixed expert_id initialization and loop bounds
    • Adjusted indexing to match new offset semantics
    • Changed loop to start from i=1 for correct comparison
    • +4/-4

Miscellaneous
  transform_replay.cpp — Add refactoring TODO for TensorDomain creation
  csrc/transform_replay.cpp
    • Added TODO for refactoring TensorDomain creation
    • Improved comment on TensorDomain duplication
    • No functional changes, only code clarity
    • +11/-2

Tests
  test_layout_op.cpp — Update test logic for new offset handling
  tests/cpp/test_layout_op.cpp
    • Updated num_group to use offset size directly
    • Fixed padded_m calculation using last offset
    • Adjusted group slicing logic for correct bounds
    • Updated test tensors to remove extra offset values
    • +26/-14

Enhancement
  indexing.h — Add helper function declaration and improve docs
  csrc/ops/indexing.h
    • Added declaration for layoutAllocationDomain helper
    • Improved documentation for layout op functions
    • No functional changes in header
    • +9/-0

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The function layoutAllocationDomain uses the extent of input_offsets directly as num_groups. It's unclear whether this correctly reflects the number of groups in all cases, especially given the refactor from offsets of size num_groups + 1 to num_groups; if num_groups is not properly derived, the allocation domain calculations could be incorrect.

    std::vector<IterDomain*> layoutAllocationDomain(
        std::vector<IterDomain*> logical_dom,
        Val* num_groups,
        BlockScalingFactorLayout layout) {
      NVF_ERROR_EQ(logical_dom.size(), 2);
    
      // Create the allocation domain of output.
      std::vector<IterDomain*> alloc_dom;
      alloc_dom.reserve(logical_dom.size());
    
      // only Block128x4 is supported at this point.
      NVF_CHECK_EQ(layout, BlockScalingFactorLayout::Block128x4);
      constexpr int col_multiple = 4;
      constexpr int row_multiple = 128;
    
      auto* one_val = num_groups->fusion()->oneVal(DataType::Index);
    
      // Note: output allocation domain handles potential padding required for the
      // layout. Since the actual padding size is data-dependent, we allocate for
      // the maximum padding (reflected on logical/allocation domain).
    
      // NOTE: We could use resize operations for the padding logic, I think this
      // might simplify predication. Not doing that for now for simpler
      // implementation. We'll re-evaluate when we add scheduler support.
    
      // pad row size: num_groups * (row_multiple - 1) + row_size
      auto pad_to_max_extent = [&](IterDomain* id, int multiple) -> IterDomain* {
        auto* maximum_pad_value_per_group =
            IrBuilder::create<Val>(multiple - 1, DataType::Index);
    
        // NOTE: we do not use `resize` to represent the padding.
        //
        // resize sounds good in theory, because transformation can propagate across
        // it. In reality, we do not have a protocol to index this operation via the
        // logical to allocation domain transform. I question how much a resize op
        // provides in functionality. More importantly, using resize hits asserts in
        // vectorization analysis (validateDeviceSplit ATM), which doesn't look easy
        // to handle for me.
        Val* padded_ext = SimplifyingIrBuilder::addExpr(
            id->extent(),
            SimplifyingIrBuilder::mulExpr(num_groups, maximum_pad_value_per_group));
        return IterDomainBuilder(id).extent(padded_ext).build();
      };
      alloc_dom.push_back(pad_to_max_extent(logical_dom[0], row_multiple));
    
      // pad col size: (col_size + col_multiple - 1) / col_multiple * col_multiple
      auto pad_to_multiple = [&](IterDomain* id, int multiple) -> IterDomain* {
        Val* ext = id->extent();
        auto* multiple_val = IrBuilder::create<Val>(multiple, DataType::Index);
        // Just as the comment above, we do NOT use resize op.
        Val* padded_ext = SimplifyingIrBuilder::mulExpr(
            SimplifyingIrBuilder::divExpr(
                SimplifyingIrBuilder::subExpr(
                    SimplifyingIrBuilder::addExpr(ext, multiple_val), one_val),
                multiple_val),
            multiple_val);
        return IterDomainBuilder(id).extent(padded_ext).build();
      };
      alloc_dom.push_back(pad_to_multiple(logical_dom[1], col_multiple));
    
      return alloc_dom;
    }
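    The two padding lambdas above reduce to simple arithmetic on the extents. A minimal Python sketch (function name hypothetical) of the same math, using the test's sizes:

```python
def padded_extents(row_size, col_size, num_groups,
                   row_multiple=128, col_multiple=4):
    # Worst-case row padding: each group may need up to (row_multiple - 1)
    # extra rows to reach a multiple of row_multiple, and the actual
    # per-group sizes are data-dependent, so allocate for the maximum.
    padded_rows = row_size + num_groups * (row_multiple - 1)
    # Column padding simply rounds up to the next multiple of col_multiple.
    padded_cols = (col_size + col_multiple - 1) // col_multiple * col_multiple
    return padded_rows, padded_cols

print(padded_extents(512, 9, 3))  # (893, 12)
```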
    Possible Issue

    The loop in preprocessGroupedMatmulInputSf starts from i = 1 and compares row_idx with input_offsets[i], assuming input_offsets has at least one element beyond the initial zero. With the refactor reducing the size of offsets, this may lead to out-of-bounds access if group_size is not adjusted accordingly.

    int expert_id = group_size - 1;
    for (int i = 1; i < group_size; ++i) {
      if (row_idx < input_offsets[i]) {
        expert_id = i - 1;
        break;
      }
    }
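    The loop's semantics under the new offset convention can be mirrored in a small Python sketch (function name hypothetical): with num_groups offsets starting at 0, the owning group of a row is the last i with input_offsets[i] <= row_idx, and rows at or past the last offset default to the final group.

```python
def find_expert_id(row_idx, input_offsets):
    # Mirrors the CUDA loop above: default to the last group, then
    # break at the first offset strictly greater than row_idx.
    group_size = len(input_offsets)
    expert_id = group_size - 1
    for i in range(1, group_size):
        if row_idx < input_offsets[i]:
            expert_id = i - 1
            break
    return expert_id

offsets = [0, 2, 5]  # tokens per expert [2, 3, 5]
print([find_expert_id(r, offsets) for r in range(10)])
# [0, 0, 1, 1, 1, 2, 2, 2, 2, 2]
```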
    Test Coverage

    The test cases use hardcoded offset values that assume specific group sizes. With the refactor, it's important to verify that the new offset logic (without the final total length) is correctly tested across various group configurations, including edge cases like single-group or empty groups.

      // memory, because we do indexing on output inside the runtime function.
      fusion.addOutput(out_tv);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      int m = 512;
      int k = 9; // note: padded column size would be 12
      // auto t0 = at::randn({m, k}, options);
      auto t0 = at::arange(m * k, options).reshape({m, k});
      // tokens per group are [100, 150, 262] respectively, so each group would be
      // padded to multiple of 128. Hence the total output row span would cover a
      // length of 128 + 256 + 384 = 768.
      auto t1 = at::tensor({0, 100, 250}, options.dtype(at::kInt));
      auto t2 = at::tensor({0, 128, 384}, options.dtype(at::kInt));
    
      // naive scheduling.
      for (auto tv : {inp, inp_tv, out_tv}) {
        tv->axis(0)->parallelize(ParallelType::BIDx);
        tv->axis(1)->parallelize(ParallelType::TIDx);
      }
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1, t2});
      auto outputs = ke.run({t0, t1, t2});
    
      ASSERT_TRUE(validateGroupedLayout(
          BlockScalingFactorLayout::Block128x4,
          outputs[0].as<at::Tensor>(),
          t0,
          t1,
          t2));
    }
    
    TEST_F(LayoutOpTest, SchedulerKernel) {
      auto fusion_ptr = std::make_unique<Fusion>();
      Fusion& fusion = *fusion_ptr.get();
      FusionGuard fg(&fusion);
    
      auto inp = makeSymbolicTensor(2);
      auto offsets = makeSymbolicTensor(1, DataType::Int32);
      auto rounded_offsets = makeSymbolicTensor(1, DataType::Int32);
      fusion.addInput(inp);
      fusion.addInput(offsets);
      fusion.addInput(rounded_offsets);
    
      auto out_tv = preprocessGroupedMatmulInputSf(
          inp, offsets, rounded_offsets, BlockScalingFactorLayout::Block128x4);
      fusion.addOutput(out_tv);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      int m = 512;
      int k = 9; // note: padded column size would be 12
      auto t0 = at::randn({m, k}, options);
      // tokens per group are [100, 150, 262] respectively, so each group would be
      // padded to multiple of 128. Hence the total output row span would cover a
      // length of 128 + 256 + 384 = 768.
      auto t1 = at::tensor({0, 100, 250}, options.dtype(at::kInt));
      auto t2 = at::tensor({0, 128, 384}, options.dtype(at::kInt));
    
      // running through automatic scheduler.
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
    
      ASSERT_TRUE(validateGroupedLayout(
          BlockScalingFactorLayout::Block128x4,
          outputs[0].as<at::Tensor>(),
          t0,
          t1,
          t2));
    }
    
    TEST_F(LayoutOpTest, SchedulerKernelWithConsumer) {
      auto fusion_ptr = std::make_unique<Fusion>();
      Fusion& fusion = *fusion_ptr.get();
      FusionGuard fg(&fusion);
    
      auto inp = makeSymbolicTensor(2);
      auto offsets = makeSymbolicTensor(1, DataType::Int32);
      auto rounded_offsets = makeSymbolicTensor(1, DataType::Int32);
      fusion.addInput(inp);
      fusion.addInput(offsets);
      fusion.addInput(rounded_offsets);
    
      auto out_tv = preprocessGroupedMatmulInputSf(
          inp, offsets, rounded_offsets, BlockScalingFactorLayout::Block128x4);
      fusion.addOutput(out_tv);
    
      // This is not allowed and we should error out since layout op output should
      // only be consumed by grouped_matmul op
      auto relu_tv = relu(out_tv);
      fusion.addOutput(relu_tv);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      int m = 512;
      int k = 9; // note: padded column size would be 12
      auto t0 = at::randn({m, k}, options);
      // tokens per group are [100, 150, 262] respectively, so each group would be
      // padded to multiple of 128. Hence the total output row span would cover a
      // length of 128 + 256 + 384 = 768.
      auto t1 = at::tensor({0, 100, 250}, options.dtype(at::kInt));
      auto t2 = at::tensor({0, 128, 384}, options.dtype(at::kInt));
    
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      EXPECT_ANY_THROW(executor_cache.runFusionWithInputs({t0, t1, t2}));
    }
    
    TEST_F(LayoutOpTest, SchedulerKernelWithOffsetsProducer) {
      auto fusion_ptr = std::make_unique<Fusion>();
      Fusion& fusion = *fusion_ptr.get();
      FusionGuard fg(&fusion);
    
      auto inp = makeSymbolicTensor(2);
      auto offsets = makeSymbolicTensor(1, DataType::Int32);
      auto rounded_offsets = makeSymbolicTensor(1, DataType::Int32);
      fusion.addInput(inp);
      fusion.addInput(offsets);
      fusion.addInput(rounded_offsets);
    
      // fusion should segment here, because layout op requires offsets to be in
      // global memory
      auto offsets_add = add(offsets, fusion.oneVal());
      auto rounded_offsets_add = add(rounded_offsets, fusion.oneVal());
    
      auto out_tv = preprocessGroupedMatmulInputSf(
          inp,
          offsets_add,
          rounded_offsets_add,
          BlockScalingFactorLayout::Block128x4);
      fusion.addOutput(out_tv);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      int m = 512;
      int k = 9; // note: padded column size would be 12
      auto t0 = at::randn({m, k}, options);
      // tokens per group are [100, 150, 262] respectively, so each group would be
      // padded to multiple of 128. Hence the total output row span would cover a
      // length of 128 + 256 + 384 = 768.
      auto t1 = at::tensor({0, 100, 250, 512}, options.dtype(at::kInt));
      auto t2 = at::tensor({0, 128, 384, 768}, options.dtype(at::kInt));
    
      // naive scheduling.
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs({t0, t1.sub(1), t2.sub(1)});
    
      ASSERT_TRUE(validateGroupedLayout(
          BlockScalingFactorLayout::Block128x4,
          outputs[0].as<at::Tensor>(),
          t0,
          t1,
          t2));
    }
    
    TEST_F(LayoutOpTest, SchedulerKernelWithExplicitQuantizationPattern) {
      auto fusion_ptr = std::make_unique<Fusion>();
      Fusion& fusion = *fusion_ptr.get();
      FusionGuard fg(&fusion);
    
      auto inp = makeSymbolicTensor(2);
      auto offsets = makeSymbolicTensor(1, DataType::Int32);
      auto rounded_offsets = makeSymbolicTensor(1, DataType::Int32);
      fusion.addInput(inp);
      fusion.addInput(offsets);
      fusion.addInput(rounded_offsets);
    
      auto block_size = IrBuilder::create<Val>(16, DataType::Int);
      auto remainder = ceilDiv(inp->axis(1)->extent(), block_size);
    
      auto reshaped_inp =
          reshape(inp, {inp->axis(0)->extent(), remainder, block_size});
      auto blocked_sf = max(reshaped_inp, {2});
      auto scaled_output = reshape(
          div(reshaped_inp, broadcast(blocked_sf, {false, false, true})),
          {inp->axis(0)->extent(), inp->axis(1)->extent()});
      // NOTE: output needs to be casted to DataType::Float4_e2m1fn, skipping that
      // for simplicity for validation
      fusion.addOutput(scaled_output);
    
      auto out_blocked_sf_fp8 = preprocessGroupedMatmulInputSf(
          blocked_sf,
          offsets,
          rounded_offsets,
          BlockScalingFactorLayout::Block128x4);
      // NOTE: output needs to be casted to DataType::Float8_e4m3fn, skipping that
      // for simplicity for validation
      fusion.addOutput(out_blocked_sf_fp8);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      int m = 512;
      int k = 9 * 16; // note: padded column size needs to be a multiple of 16
      auto t0 = at::randn({m, k}, options);
      // tokens per group are [100, 150, 262] respectively, so each group would be
      // padded to multiple of 128. Hence the total output row span would cover a
      // length of 128 + 256 + 384 = 768.
      auto t1 = at::tensor({0, 100, 250}, options.dtype(at::kInt));
      auto t2 = at::tensor({0, 128, 384}, options.dtype(at::kInt));
    
      // automatic scheduling.
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});

    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 34b394d to 3f56fcc Compare September 22, 2025 21:54
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch 2 times, most recently from c40c78f to 58885a8 Compare September 22, 2025 22:34
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from c6ea939 to 64a9f9e Compare September 23, 2025 17:41
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from 5e62117 to 9547246 Compare September 23, 2025 17:41
    jjsjann123 added a commit that referenced this pull request Sep 24, 2025
    Shouldn't have sliced swizzled & padded data when copying it back to the
    original buffer. The issue was noticed when I try to validate the layout
    op in #5198
    
    This unfortunately didn't affect the threshold for result validation.
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from 39fa2b3 to b3fbe15 Compare October 1, 2025 18:18
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch 2 times, most recently from 9c2d85e to 30d1699 Compare October 7, 2025 22:36
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from b3fbe15 to 64d1651 Compare October 7, 2025 23:27
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 012365d to 6d0b512 Compare October 7, 2025 23:44
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from 64d1651 to 79e1f8c Compare October 7, 2025 23:44
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR4_scheduler_support branch from 6d0b512 to 350f318 Compare October 7, 2025 23:49
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch 2 times, most recently from a110e89 to a9821a0 Compare October 8, 2025 11:58
    @jjsjann123 jjsjann123 changed the title end-2-end nvfp4 grouped_mm via direct binding refactor number of groups in layout op Oct 8, 2025
    @jjsjann123 jjsjann123 mentioned this pull request Oct 8, 2025
    // set it as IterType::Symbolic to avoid domain validation errors
    return IterDomainBuilder(id)
    .extent(padded_ext)
    .iter_type(IterType::Symbolic)
    Collaborator Author

    another change

    Collaborator

    What validation are you referring to?

    Collaborator Author

    This is coming from
    nvfuser::ir_utils::validateDomainEquivalence

    The error occurs during transformation replay, specifically, during transform propagation.

    (gdb) f 5
    #5  0x0000aaaaabd018fc in nvfuser::TransformReplay::replayCasP (consumer=0xfffc8c00d720, producer=0xfffc8c005fc0, 
        producer_pos=2, logical_map=..., opt=...) at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:870
    870         TensorDomain* replayed = IrBuilder::createInContainer<TensorDomain>(
    (gdb) p consumer->printTransform()
    Couldn't find method nvfuser::TensorView::printTransform
    (gdb) p consumer->printTransforms()
     logical domain : (iS54{i0}, iS55{9})
     allocation domain : (iS56{( i0 + ( i3 * 127 ) )}, iS57{( ( 12 / 4 ) * 4 )})
     contiguity: t t
     loop domain : (iS54{i0}, iS55{9})
    $3 = void
    (gdb) p producer->printTransforms()
     logical domain : (iS39{i0}, iS40{9}, rS11{16})
     contiguity: t t n
     loop domain : (iS40{9}, rS11{16}, iS39{i0})
    $4 = void
    (gdb) p new_loop[0]->toString(0)
    $5 = "iS55{9}"
    (gdb) p new_loop[1]->toString(0)
    $6 = "iS54{i0}"
    (gdb) 
    

    The program tried to update the consumer's domain as

     870     TensorDomain* replayed = IrBuilder::createInContainer<TensorDomain>(
     871         consumer->container(),
     872         consumer->getRootDomain(),
     873         consumer->getLogicalDomain(),
     874         consumer->getAllocationDomain(),
     875         new_loop,
     876         consumer->domain()->contiguity());
    

    It won't work because the allocation domain doesn't match the logical domain. However, by marking the padded allocation domain as symbolic, ir_utils::compareDomains wouldn't try to map it, which silences the error.

    Here's the full stack

    #2  0x0000aaaaab64e5bc in nvfuser::ir_utils::validateDomainEquivalence (
        dom0=std::vector of length 2, capacity 2 = {...}, dom1=std::vector of length 2, capacity 2 = {...},
        additional_ids=std::vector of length 0, capacity 0) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1163
    #3  0x0000aaaaab4ea6e0 in nvfuser::TensorDomain::TensorDomain (this=0xfffc8c037c70, passkey=...,
        root_domain=std::vector of length 0, capacity 0, logical_domain=std::vector of length 0, capacity 0,
        allocation_domain=std::vector of length 0, capacity 0, loop_domain=std::vector of length 0, capacity 0,
        contiguity=std::vector of length 0, capacity 0, additional_ids=std::vector of length 0, capacity 0)
        at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:3417
    #4  0x0000aaaaabd0a36c in nvfuser::IrBuilder::createInContainer<nvfuser::TensorDomain, std::vector<nvfuser::IterDomain
    *, std::allocator<nvfuser::IterDomain*> > const&, std::vector<nvfuser::IterDomain*, std::allocator<nvfuser::IterDomain
    *> > const&, std::vector<nvfuser::IterDomain*, std::allocator<nvfuser::IterDomain*> > const&, std::vector<nvfuser::Ite
    rDomain*, std::allocator<nvfuser::IterDomain*> >&, std::vector<std::optional<bool>, std::allocator<std::optional<bool>
     > > const&> (container=0xfffc8c000e30) at /opt/pytorch/nvfuser/csrc/ir/builder.h:46
    #5  0x0000aaaaabd018fc in nvfuser::TransformReplay::replayCasP (consumer=0xfffc8c00d720, producer=0xfffc8c005fc0,
        producer_pos=2, logical_map=..., opt=...) at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:870
    #6  0x0000aaaaabd0216c in nvfuser::TransformReplay::replayCasP (consumer=0xfffc8c00d720, producer=0xfffc8c005fc0,
        compute_at_axis=2, opt=...) at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:948
    #7  0x0000aaaaabd035dc in nvfuser::TransformPropagator::propagateP2C (this=0xfffcd5db9d50, from=0xfffc8c005fc0,
        to=0xfffc8c00d720) at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:1206
    #8  0x0000aaaaabc22428 in nvfuser::MaxInfoSpanningTree::traverse (this=0xfffcd5db9d90, propagator=0xfffcd5db9d50)
        at /opt/pytorch/nvfuser/csrc/scheduler/tools/maxinfo_propagator.cpp:158
    #9  0x0000aaaaabc612e0 in nvfuser::scheduler_utils::transformPropagateToAllFrom (from_tv=0xfffc8c00f710, pos=2)
        at /opt/pytorch/nvfuser/csrc/scheduler/utils.cpp:1935
    #10 0x0000aaaaabc648c8 in nvfuser::scheduler_utils::propagateReshapeTransforms (fusion=0xfffc8c000e30, ca_map=...)
        at /opt/pytorch/nvfuser/csrc/scheduler/utils.cpp:2460
    #11 0x0000aaaaabb62fb4 in nvfuser::normalization_scheduler_utils::scheduleReductionGeneral (fusion=0xfffc8c000e30,
        rparams=0xaaaab29a03f0, reduction_tvs=std::vector of length 1, capacity 1 = {...},
        scheduler_type=nvfuser::SchedulerType::InnerPersistent)
        at /opt/pytorch/nvfuser/csrc/scheduler/normalization_utils.cpp:1527
    #12 0x0000aaaaabb633c8 in nvfuser::normalization_scheduler_utils::schedulePersistentKernel (fusion=0xfffc8c000e30, 
        rparams=0xaaaab29a03f0, scheduler_type=nvfuser::SchedulerType::InnerPersistent)
        at /opt/pytorch/nvfuser/csrc/scheduler/normalization_utils.cpp:1586
    #13 0x0000aaaaabaef90c in nvfuser::InnerPersistentKernelScheduler::schedule (this=0xfffc8c01ed20, 
        fusion=0xfffc8c000e30, params=0xaaaab29a03f0) at /opt/pytorch/nvfuser/csrc/scheduler/normalization_inner.cpp:1226
    

    Collaborator

    OK, scatter has a similar problem with the loop domain, so I added the skip_loop_validation flag. I don't quite like the idea of using IterType::Symbolic, as it's meant to indicate that the iteration type is still unknown, which isn't the situation here.

    Have you considered any other workarounds?

    Collaborator Author

    @jjsjann123 jjsjann123 Oct 14, 2025

    Initially I tried to follow the same route and add skip_validation for setAllocationDomain. The reason I used this hack was a coincidence (i.e. this issue didn't show up until I fixed dynamic_transform in #5345, and when comparing the fusion IR side by side, I realized symbolic iter type would let me get away with it 😛).

    Have you considered any other workarounds?

    IMHO, we shouldn't be validating domain equivalence, since the new direction we are moving towards is trying to genuine support this as a feature. Or rather, we should instead opt-in to run the validation.

    Collaborator Author

    sorry about that typo. don't know where that genuine was coming from.

    There are certain properties that must hold, and making sure those properties are satisfied seems like a good idea to me.

    I agree that validation on properties that must hold is a good idea.

    Given that we have now started using the allocation domain to represent padding/swizzle, and similarly use the loop domain to represent sparse update in scatter, I think that's saying that domain equivalence between logical and allocation, as well as between logical and loop, no longer holds, and we shouldn't require the equivalence during transformations.

    We can certainly still apply those checks, but as an opt-in for cases where we know such properties hold.

    Collaborator Author

    auto replayed = TensorDomainBuilder(consumer).loop(new_loop).build();

    Sounds like a reasonable WAR to try for this.

    I think the scatter case where the loop domain on the consumer doesn't converge to the logical domain would still be unhappy with it. Though we haven't got any cases where we replay across a scatter op in the scheduler.

    Collaborator

    Given that we now started using allocation domain to represent padding/swizzle, similarly we use loop domain to represent sparse update in scatter. I think that's saying that domain equivalence between logical to allocation, as well as logical to loop, no longer hold and we shouldn't require the equivalency during transformations.

    Please think about what we could do to make the overall system more robust rather than just throwing away safety checks. In the case of allocation domains, it seems to be that we should at least check if each iter domain is reachable and symbolically validate the size of an allocation domain is equal to or larger than that of the logical domain.

    Sounds like a reasonable WAR to try for this.

    To be clear, I didn't consider it just a WAR; I proposed it as a valid design. If something is already valid, why should we need to validate it again?

    Collaborator Author

    Got'ya.

    In the case of allocation domains, it seems to be that we should at least check if each iter domain is reachable and symbolically validate the size of an allocation domain is equal to or larger than that of the logical domain.

    Does that mean we would need to have ID op connecting logical to allocation domain (like a resize), rather than relying on expression directly on the extent like in our current implementation?

    Fuser/csrc/ops/indexing.cpp

    Lines 353 to 379 in 6219364

    // Note: output logical domain handles potential padding required for the
    // layout. Since the actual padding size is data-dependent, we allocate for
    // the maximum padding (reflected on logical/allocation domain).

    // NOTE: We could use resize operations for the padding logic, I think this
    // might simplify predication. Not doing that for now for simpler
    // implementation. We'll re-evaluate when we add scheduler support.

    // pad row size: num_groups * (row_multiple - 1) + row_size
    auto pad_to_max_extent = [&](IterDomain* id, int multiple) -> IterDomain* {
      auto* maximum_pad_value_per_group =
          IrBuilder::create<Val>(multiple - 1, DataType::Index);

      // NOTE: we do not use `resize` to represent the padding.
      //
      // resize sounds good in theory, because transformation can propagate across
      // it. In reality, we do not have a protocol to index this operation via the
      // logical to allocation domain transform. I question how much a resize op
      // provides on functionality. More importantly, resize still hits asserts in
      // vectorization analysis (validateDeviceSplit ATM), which doesn't look easy
      // to handle for me.
      Val* padded_ext = SimplifyingIrBuilder::addExpr(
          id->extent(),
          SimplifyingIrBuilder::mulExpr(num_groups, maximum_pad_value_per_group));
      return IterDomainBuilder(id).extent(padded_ext).build();
    };
    out_alloc_dom.push_back(pad_to_max_extent(out_logical_dom[0], row_multiple));

    Collaborator Author

    Opened #5383

    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from e749d3f to 17dd08a Compare October 8, 2025 18:14
    @jjsjann123 jjsjann123 marked this pull request as ready for review October 8, 2025 18:14
    @jjsjann123
    Collaborator Author

    !test

    @jjsjann123 jjsjann123 requested a review from naoyam October 8, 2025 18:15
    jjsjann123 added a commit that referenced this pull request Oct 9, 2025
    ## Stacked PRs
    
    #5230 moe layer with nvfp4 grouped_mm
    #5345 exposing layout op at direct python binding
    #5198 refactor number of groups in layout op
    #5174 allow layout op in automatic scheduler  <-- this PR
    
    ## This PR
    
    Allow scheduler to take `PreprocessGroupedMatmulInputSf` as a pointwise
    operation using the runtime function.
    
    The main code change is to address the assumptions of the runtime
    function:
    
    - [x] add segmentation for offsets to ensure they are in global memory.
    * Existing assumption is that two offsets inputs and output of the
    layout op would be in global memory, where the runtime function could
    read the entirety of both offsets and write the output via data
    dependent indexing. This allows the operation to be treated as a trivial
    pointwise-op.
        * avoids caching layout op outputs or offsets inputs.
    * avoids putting layout op output into persistent buffers (since we
    require write to global memory).
    
    - [x] detect unsafe consumption of PreprocessGroupedMatmulInputSf output
    in `fusion_segmenter.cpp`
    - [x] relax asserts on assumption that there's always a legit path
    between loop->allocation and logical->allocation in some scheduler
    utils.
    
    
    TODOs for future PR:
    
    * end-2-end python test with direct binding.
    Base automatically changed from jj/layout_op_PR4_scheduler_support to main October 9, 2025 21:28
    @naoyam
    Collaborator

    naoyam commented Oct 9, 2025

    @jjsjann123 Can you fix the conflicts to clean up the diffs?

    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from 17dd08a to f669051 Compare October 10, 2025 00:44
    @jjsjann123
    Copy link
    Collaborator Author

    @jjsjann123 Can you fix the conflicts to clean up the diffs?

    Cleaned up now. I wish GitHub would at least try a rebase when the target branch is merged. 🤷

    @jjsjann123
    Copy link
    Collaborator Author

    !test

    @jjsjann123
    Copy link
    Collaborator Author

    !test

    Copy link
    Collaborator

    @naoyam naoyam left a comment


    LGTM. Thanks for the update.

    @jjsjann123
    Copy link
    Collaborator Author

    !test

    @jjsjann123 jjsjann123 merged commit 01dfc85 into main Oct 16, 2025
    58 of 60 checks passed
    @jjsjann123 jjsjann123 deleted the jj/layout_op_PR5_python_direct_binding branch October 16, 2025 20:11
    jjsjann123 added a commit that referenced this pull request Oct 21, 2025
    ## Stacked PRs
    
    #5230 moe layer with nvfp4 grouped_mm
    #5345 exposing layout op at direct python binding  <-- this PR
    #5198 refactor number of groups in layout op
    #5174 allow layout op in automatic scheduler
    
    ## This PR
    
    Expose layout op at python direct binding.
    Added nvfp4 grouped gemm in python test.
    
    Minor fixes:
    
    1. ~Added support of allocation domain for output of layout op in
    concretization pass to maintain the dependency on padded allocation
    domain to its logical domain.~ No longer needed, handled in
    #5384
    2. Skipped validation for `setAllocationDomain`
    3. Updated the reference implementation to match the math order in
    nvFuser's decomposed nvfp4 quantization.
    
    TODO:
    
    Python tests require the IdModel indexer in order to work. See issue #5200,
    as well as the suggested WAR in
    #5200 (comment)
    tbqh pushed a commit that referenced this pull request Nov 12, 2025
    tbqh pushed a commit that referenced this pull request Nov 12, 2025
    ## Stacked PRs
    
    #5230 moe layer with nvfp4 grouped_mm
    #5345 exposing layout op at direct python binding
    #5198 refactor number of groups in layout op  <-- this PR
    #5174 allow layout op in automatic scheduler
    
    ## This PR
    
    This is a tiny refactor so that the two `offsets` are expected to have
    size equal to num_groups. The reason is that our cutlass kernel expected
    that in the first place, and I didn't match it correctly the first time.
    
    e.g. with total sequence length 10 and tokens per expert [2, 3, 5],
    previously the offsets would be [0, 2, 5, 10]; after the refactor, the
    offsets are [0, 2, 5].
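    The new offsets are simply an exclusive prefix sum of tokens-per-expert, truncated to num_groups entries. A minimal sketch (hypothetical helper name, not nvFuser or cutlass code):
    
    ```python
    def group_offsets(tokens_per_expert):
        """Exclusive prefix sum of group sizes, with length == num_groups.
    
        Illustrative sketch only. Note that the running total (the total
        sequence length) is intentionally NOT appended, matching what the
        cutlass kernel expects after this refactor.
        """
        offsets = []
        running = 0
        for count in tokens_per_expert:
            offsets.append(running)
            running += count
        return offsets
    
    # Total sequence length 10, tokens per expert [2, 3, 5]:
    # group_offsets([2, 3, 5]) -> [0, 2, 5]   (previously [0, 2, 5, 10])
    ```
    
    Dropping the trailing total means the end of the last group is recovered from the tensor extent instead of from the offsets array.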
    tbqh pushed a commit that referenced this pull request Nov 12, 2025
