Skip to content

Fix itertype promotion for GatherScatter#5365

Merged
Priya2698 merged 2 commits intomainfrom
pm/gatherscatter_bug
Oct 10, 2025
Merged

Fix itertype promotion for GatherScatter#5365
Priya2698 merged 2 commits intomainfrom
pm/gatherscatter_bug

Conversation

@Priya2698
Copy link
Collaborator

@Priya2698 Priya2698 commented Oct 9, 2025

This PR

Stops propagation of IterType::GatherScatter to be propagated from producer to consumer.

Context

We are propagating IterType::GatherScatter, in the example below, the output T2 from scatter op contains nS2 in logical domain. That IterType is propagated through to its consumer T3, which felt wrong.

%kernel {
T2_l_int64_t[iS3{1024}]
   = scatter(in = T0_g_int64_t[iS0{128}], dim = 0, src = 1, idx = T1_g_int64_t[iS1{1024}], accumulate = add )
T3_g_int64_t[nS4{128}]
   = T2_l_int64_t[iS3{1024}]
   + 1;

T2_l_int64_t[iS3{1024}]
 logical domain : (nS2{128})
 contiguity: t
 loop domain : (iS3{1024})
T3_g_int64_t[nS4{128}]
 logical domain : (nS4{128})
 contiguity: t
 loop domain : (nS4{128})
} // %kernel

@Priya2698 Priya2698 requested a review from jjsjann123 October 9, 2025 23:46
@Priya2698
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented Oct 9, 2025

Review updated until commit 12d6f71

Description

  • Fix iter_type promotion for GatherScatter axes

  • Add test for GatherScatter iter type behavior

  • Ensure non-Scatter Tvs don't have GatherScatter axes


Changes walkthrough 📝

Relevant files
Bug fix
utils.cpp
Fix iter_type promotion for GatherScatter                               

csrc/ops/utils.cpp

  • Fix iter_type handling for GatherScatter IDs
  • Set iter_type to Iteration when ID is GatherScatter
  • Preserve existing promotion logic otherwise
  • +2/-0     
    Tests
    test_scatter.cpp
    Add test for GatherScatter iter type                                         

    tests/cpp/test_scatter.cpp

  • Add test case for GatherScatter iter type
  • Validate loop domains exclude GatherScatter axes
  • Test non-Scatter tensors in fusion
  • +25/-0   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Logic Error

    The promotion of iter_type for GatherScatter axes may not correctly handle all cases, as setting IterType::Iteration unconditionally on isGatherScatter could override meaningful iter types from other paths. This may affect downstream optimizations relying on accurate iter_type.

    } else if (id->isGatherScatter()) {
      iter_type = IterType::Iteration;
    } else {


    namespace nvfuser {

    TEST_F(NVFuserTest, GatherScatterIterType) {
    Copy link
    Collaborator Author

    @Priya2698 Priya2698 Oct 10, 2025

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    This is same as ScatterLargeConstrainedIDs.
    I have made this separate to keep it independent of the other test in case it is modified in the future according to its purpose.

    @Priya2698
    Copy link
    Collaborator Author

    !test

    Copy link
    Collaborator

    @jjsjann123 jjsjann123 left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM.

    @Priya2698 Priya2698 merged commit 922d6bd into main Oct 10, 2025
    65 checks passed
    @Priya2698 Priya2698 deleted the pm/gatherscatter_bug branch October 10, 2025 17:01
    tbqh pushed a commit that referenced this pull request Nov 12, 2025
    ## This PR
    
    Stops propagation of IterType::GatherScatter to be propagated from
    producer to consumer.
    
    ## Context
    
    We are propagating IterType::GatherScatter, in the example below, the
    output T2 from scatter op contains `nS2` in logical domain. That
    IterType is propagated through to its consumer T3, which felt wrong.
    
    ```
    %kernel {
    T2_l_int64_t[iS3{1024}]
       = scatter(in = T0_g_int64_t[iS0{128}], dim = 0, src = 1, idx = T1_g_int64_t[iS1{1024}], accumulate = add )
    T3_g_int64_t[nS4{128}]
       = T2_l_int64_t[iS3{1024}]
       + 1;
    
    T2_l_int64_t[iS3{1024}]
     logical domain : (nS2{128})
     contiguity: t
     loop domain : (iS3{1024})
    T3_g_int64_t[nS4{128}]
     logical domain : (nS4{128})
     contiguity: t
     loop domain : (nS4{128})
    } // %kernel
    ```
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants