
moe layer with nvfp4 grouped_mm#5230

Draft
jjsjann123 wants to merge 16 commits into main from jj/moe_fp4

Conversation


@jjsjann123 jjsjann123 commented Sep 24, 2025

Stacked PRs

#5230 moe layer with nvfp4 grouped_mm <-- this PR
#5345 exposing layout op at direct python binding
#5198 refactor number of groups in layout op
#5174 allow layout op in automatic scheduler

This PR

Fixes that I want to be reviewed:

  • GatherScatter IterType shouldn't propagate in fusion.
  • packed fp4 dtype is not properly handled in direct binding.

Changes to test_moe.py are just a proof-of-concept:

  • Enable nvfuser nvfp4 grouped_mm in MoE benchmark.

Note: This has the indexing issue in #5200. This branch also requires some thunder inflight PRs.


github-actions bot commented Sep 24, 2025

Review updated until commit 13db45d

Description

  • Added NVFP4 quantized grouped matrix multiplication support

  • Fixed GatherScatter IterType propagation in fusion

  • Enhanced MoE layer with FP4 weight handling

  • Updated group size calculation with total tokens


Changes walkthrough 📝

Relevant files:
Enhancement: tests/python/test_moe.py (Added NVFP4 grouped MM with quantization support)

  • Added _group_sizes_from_offsets with total_tokens parameter for
    accurate group sizing
  • Introduced nvfuser_f16a_nvfp4weight_scaled_grouped_mm custom op with
    fake implementation
  • Implemented gmm_nvfuser translator for NVFuser with FP4 quantization
    logic
  • Enhanced GroupedLinear to support FP4 weights, scaling factors, and
    problem sizes
  • Updated GroupedSwiGLU and Llama4MoE to handle new FP4 parameters
  • +205/-9 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The function _group_sizes_from_offsets now takes an additional total_tokens parameter, but its logic assumes offsets[1:] is valid. If offsets is empty or has only one element, this may lead to incorrect group sizes or runtime errors.

    def _group_sizes_from_offsets(offsets: torch.Tensor, total_tokens: int) -> list[int]:
        group_sizes = []
        prev = 0
        for offset in offsets[1:]:
            group_sizes.append(offset - prev)
            prev = offset
        group_sizes.append(total_tokens - prev)
        return group_sizes
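As a quick sanity check, a torch-free sketch of the same logic on plain lists (illustrative only) reproduces the group sizes for the example used elsewhere in this thread (tokens per expert [2, 3, 5], total 10):

```python
# torch-free sketch of _group_sizes_from_offsets, on plain lists for illustration
def group_sizes_from_offsets(offsets: list[int], total_tokens: int) -> list[int]:
    group_sizes = []
    prev = 0
    for offset in offsets[1:]:
        group_sizes.append(offset - prev)
        prev = offset
    # the last group runs from the final offset to the end of the sequence
    group_sizes.append(total_tokens - prev)
    return group_sizes

print(group_sizes_from_offsets([0, 2, 5], 10))  # [2, 3, 5]
```

Note the edge case flagged above: an empty `offsets` still yields a single group of size `total_tokens`, which may or may not be the intended behavior.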
    Performance Concern

    The custom op nvfuser_f16a_nvfp4weight_scaled_grouped_mm currently returns the result of grouped_mm using dropme instead of utilizing the quantized fp4_weight and scaling logic, which may defeat the purpose of the nvfp4 optimization.

    @torch.library.custom_op(
        "nvf_cutlass::f16a_nvfp4weight_scaled_grouped_mm", mutates_args=()
    )
    def nvfuser_f16a_nvfp4weight_scaled_grouped_mm(
        activation: torch.Tensor,
        fp4_weight: torch.Tensor,
        weight_scaling_factor: torch.Tensor,
        global_scale: torch.Tensor,
        offsets: torch.Tensor,
        blockscale_offsets: torch.Tensor,
        problem_sizes: torch.Tensor,
        dropme: torch.Tensor,
    ) -> torch.Tensor:
        return grouped_mm(activation, dropme, offsets)
    Correctness Risk

    In gmm_nvfuser, the problem_sizes tensor is constructed using tokens_per_expert.unsqueeze(-1) and concatenated with self.n and self.k, but it's unclear if the shape and order match the expected input for cutlass_nvfp4_grouped_mm, risking incorrect kernel behavior.

    m_size = activation.shape[0]
    k_size = activation.shape[1]
    k_tile_size = k_size // 16
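For concreteness, a hedged sketch of the per-group problem-size layout this concern is about, assuming the kernel expects one (m_i, n, k) row per group (the dimensions and ordering here are illustrative, not the verified cutlass_nvfp4_grouped_mm contract):

```python
# illustrative only: build one (m_i, n, k) row per expert group
tokens_per_expert = [2, 3, 5]   # m_i for each group (sums to m_size)
n, k = 128, 64                  # hypothetical weight dimensions
problem_sizes = [[m, n, k] for m in tokens_per_expert]
print(problem_sizes)  # [[2, 128, 64], [3, 128, 64], [5, 128, 64]]
```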

    Comment on lines +170 to +171
    #from thunder.torch.custom_op import _register_nvfuser_translator
    #_register_nvfuser_translator(_sym_of_nvfp4_scaled_grouped_mm, gmm_nvfuser)

    nit-picking, this function just calls one function twice --

        from thunder.executors.nvfuserex_impl import register_supported
        from thunder.executors.torchex import _always_executable
    
        register_supported(symbol, translator_for_nvfuser, checker or _always_executable)
        register_supported(symbol.id, translator_for_nvfuser, checker or _always_executable)

    I'm taking a bit of time cleaning up the Lightning-AI/lightning-thunder#2481 tests. To reduce the number of cherry-picks, just calling register_supported here might make things easier

    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from 39fa2b3 to b3fbe15 on October 1, 2025 18:18
    @jjsjann123 jjsjann123 force-pushed the jj/moe_fp4 branch 2 times, most recently from c4b9629 to 529149d on October 1, 2025 19:48
    @wujingyue wujingyue requested a review from zasdfgbnm on October 3, 2025 19:34
    @jjsjann123 jjsjann123 force-pushed the jj/moe_fp4 branch 2 times, most recently from 57a0da9 to 80888d9 on October 7, 2025 23:24
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch 2 times, most recently from 64d1651 to 79e1f8c on October 7, 2025 23:44
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from 79e1f8c to a110e89 on October 7, 2025 23:49
    @jjsjann123

    !test

    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR5_python_direct_binding branch from a110e89 to a9821a0 on October 8, 2025 11:58
    @jjsjann123 jjsjann123 changed the base branch from jj/layout_op_PR5_python_direct_binding to jj/layout_op_PR6_direct_binding on October 8, 2025 12:30
    @jjsjann123 jjsjann123 changed the title from "Jj/moe fp4" to "moe layer with nvfp4 grouped_mm" on Oct 8, 2025
    @jjsjann123

    !test

    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from 4c2e30a to c342568 on October 8, 2025 18:18
    @jjsjann123 jjsjann123 force-pushed the jj/moe_fp4 branch 2 times, most recently from c63a9f6 to 5ff4e24 on October 8, 2025 21:00
    @jjsjann123

    I'm undecided on how to proceed with the tests, but the other, smaller pieces should be good for review.

    @jjsjann123 jjsjann123 marked this pull request as ready for review October 8, 2025 21:08
    @jjsjann123 jjsjann123 requested a review from crcrpar October 8, 2025 21:08
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from bc84003 to 210bbe5 on October 10, 2025 00:44
    @jjsjann123

    I'll change the target and refactor this to only handle nvfp4 via direct binding, since @Priya2698 is cherry-picking the propagation in #5365.

    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from de4645d to b4a3a2f on October 14, 2025 09:04
    jjsjann123 added a commit that referenced this pull request Oct 16, 2025
    ## Stacked PRs
    
    #5230 moe layer with nvfp4 grouped_mm
    #5345 exposing layout op at direct python binding
    #5198 refactor number of groups in layout op  <-- this PR
    #5174 allow layout op in automatic scheduler
    
    ## This PR
    
    This is a tiny refactor to only expect the two `offsets` to have a size
    equal to num_groups. The reason is that our cutlass kernel expected
    that in the first place and I didn't match it correctly the first time.
    
    e.g. with total sequence length 10, and tokens per expert [2, 3, 5]
    Previously the offsets would be [0, 2, 5, 10]; after the refactor, the
    offsets would be [0, 2, 5].
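    The offsets change in the example above can be sketched with a prefix sum (plain python, illustrative):

```python
from itertools import accumulate

tokens_per_expert = [2, 3, 5]                    # total sequence length 10
inclusive = list(accumulate(tokens_per_expert))  # [2, 5, 10]

old_offsets = [0] + inclusive        # before the refactor: [0, 2, 5, 10]
new_offsets = [0] + inclusive[:-1]   # after: one entry per group, [0, 2, 5]
print(old_offsets, new_offsets)
```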
    @jjsjann123 jjsjann123 force-pushed the jj/layout_op_PR6_direct_binding branch from b4a3a2f to fe8da62 on October 16, 2025 20:23
    jjsjann123 added a commit that referenced this pull request Oct 17, 2025
    Cherry-picked from #5230 
    
    * packed fp4 dtype needs to be supported by the python API in order to
    support framework integration.
    
    FusionDefinition does not expect a packed dtype. But since that's the
    only fp4 dtype supported by the framework, our integration still needs
    to support it.
    
    This PR adds a quick translation at `FusionDefinition.define_tensor` to
    translate the packed dtype into an unpacked dtype, keeping the WAR
    transparent to integration/users.
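    A minimal sketch of that translation WAR, with hypothetical dtype names (the actual nvFuser enum values and mapping may differ):

```python
# hypothetical dtype names; the real nvFuser/PyTorch enum values may differ
PACKED_TO_UNPACKED = {
    "float4_e2m1fn_x2": "float4_e2m1fn",  # packed fp4 (two values/byte) -> unpacked
}

def translate_dtype(dtype: str) -> str:
    """Map a packed fp4 dtype to its unpacked counterpart; pass others through."""
    return PACKED_TO_UNPACKED.get(dtype, dtype)

print(translate_dtype("float4_e2m1fn_x2"))  # float4_e2m1fn
print(translate_dtype("bfloat16"))          # bfloat16
```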
    jjsjann123 added a commit that referenced this pull request Oct 21, 2025
    ## Stacked PRs
    
    #5230 moe layer with nvfp4 grouped_mm
    #5345 exposing layout op at direct python binding  <-- this PR
    #5198 refactor number of groups in layout op
    #5174 allow layout op in automatic scheduler
    
    ## This PR
    
    Expose layout op at python direct binding.
    Added nvfp4 grouped gemm in python test.
    
    Minor fixes:
    
    1. ~Added support of allocation domain for output of layout op in
    concretization pass to maintain the dependency on padded allocation
    domain to its logical domain.~ No longer needed, handled in
    #5384
    2. Skipped validation for `setAllocationDomain`.
    3. Updated the reference implementation to match the math order in
    nvfuser's decomposed nvfp4 quantization.
    
    TODO:
    
    Python tests require the IdModel Indexer in order to work. See issue #5200,
    as well as the suggested WAR in
    #5200 (comment)
    Base automatically changed from jj/layout_op_PR6_direct_binding to main October 21, 2025 04:34
    @jjsjann123

    All code changes have been merged in separate PRs; only test_moe.py is being updated in this PR. I'll clean it up and request a sanity check afterwards.

    tbqh pushed a commit that referenced this pull request Nov 12, 2025
    ## Stacked PRs
    
    #5230 moe layer with nvfp4 grouped_mm
    #5345 exposing layout op at direct python binding
    #5198 refactor number of groups in layout op
    #5174 allow layout op in automatic scheduler  <-- this PR
    
    ## This PR
    
    Allow the scheduler to take `PreprocessGroupedMatmulInputSf` as a pointwise
    operation using the runtime function.
    
    The main code change addresses the assumptions of the runtime
    function:
    
    - [x] Add segmentation for offsets to ensure they are in global memory.
      * The existing assumption is that the two offsets inputs and the output
        of the layout op are in global memory, where the runtime function can
        read the entirety of both offsets and write the output via
        data-dependent indexing. This allows the operation to be treated as a
        trivial pointwise op.
      * Avoids caching layout op outputs or offsets inputs.
      * Avoids putting the layout op output into persistent buffers (since we
        require a write to global memory).
    
    - [x] Detect unsafe consumption of PreprocessGroupedMatmulInputSf output
      in `fusion_segmenter.cpp`.
    - [x] Relax asserts on the assumption that there's always a legit path
      between loop->allocation and logical->allocation in some scheduler
      utils.
    
    
    TODOs for future PR:
    
    * end-2-end python test with direct binding.
    tbqh pushed a commit that referenced this pull request Nov 12, 2025 (duplicate of the #5198 commit message above)
    tbqh pushed a commit that referenced this pull request Nov 12, 2025 (duplicate of the packed fp4 dtype commit message above)
    tbqh pushed a commit that referenced this pull request Nov 12, 2025 (duplicate of the #5345 commit message above)