
More precise WAR for resize vectorization#4305

Merged
naoyam merged 6 commits into main from fix_resize_vec
Apr 25, 2025
Conversation

@naoyam
Collaborator

@naoyam naoyam commented Apr 24, 2025

This is a follow-up to #3906, which added a WAR for #3640. While the WAR is safe, it turned out to be too conservative. For example, here's a concat pattern that appears in the backward pass of the Litgpt Llama RoPE module:

Inputs:
  T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}]
  T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}]
  T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}]
Outputs:
  T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf]

%kernel_math {
T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}]
   = pad( T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}], {0, 0, 0, 0, 0, 2, 0, 0, 0, 0} )
i31 = 0 + 4;
T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}]
   = pad( T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}], {0, 0, 0, 0, i31, 1, 0, 0, 0, 0} )
i47 = i31 + 1;
T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}]
   = pad( T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}], {0, 0, 0, 0, i47, 0, 0, 0, 0, 0} )
T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}]
   = cat( T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}], T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}], T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}], 2 )
T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}]
   = Set.Permute( T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}], cache_op=Streaming )
T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf] = view( T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}] )
} // %kernel_math

This fusion is currently taken by the pointwise scheduler, which attempts to vectorize the innermost ID of the output (i.e., iS52{6144}). Since the resize ops of the three pad ops are reachable from iS52, the WAR of #3640 simply folds them into the vectorization factor by taking the gcd with the left and right expand factors. In this case, since one of the expand factors is 1, the resulting vectorization factor is also just 1, which is clearly not what we want. Here, while the resized ID itself is not vectorizable due to the expand factor of 1, all of the resized tensors have inner IDs large enough to allow the maximum vectorization.

To make the WAR a little less conservative, this PR also checks whether the constraint imposed by a Resize expr may be missed by the vectorization analysis. In the above case, that should not happen, as there's only one path through each of the resize-based tensor ops.

This change is still not able to eliminate false positives completely. See one of the new tests that is currently disabled.

The code-diff results all seem to make sense: http://nv/eFb. Previously, some of the tests had no vectorization due to the WAR; this PR relaxes it and allows some vectorization.

@github-actions

github-actions bot commented Apr 24, 2025

Review updated until commit c01a3b8

Description

  • Improved WAR for vectorizing through resized iter domains

  • Added precise check for resize expr reachability

  • Enhanced test cases for vectorization with resize


Changes walkthrough 📝

Relevant files

Enhancement
vectorize_helper.cpp: Enhance resize vectorization WAR
csrc/scheduler/vectorize_helper.cpp
  • Added CanSkipResize class for permissive BFS traversal
  • Updated getResizeVectorizationFactors to use CanSkipResize
  • Improved logic to collect resize factors
+85/-64

Tests
test_resize.cpp: Add precise vectorization tests
tests/cpp/test_resize.cpp
  • Renamed VectorizeSliceMultiplePaths to VectorizeInnerSliceMultiplePaths
  • Added DISABLED_VectorizeOuterSliceMultiplePaths test
  • Added VectorizeOuterPad test
+90/-1

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Error

The CanSkipResize::run function is called with resize as an argument, but inside the function, resize is redefined as a Resize* from logical_id->definition(). This could lead to incorrect behavior or segmentation faults if logical_id->definition() does not return a Resize*.

for (auto resize : resize_based_ops) {
  auto resize_out_tv = resize->output(0)->as<TensorView>();
  for (const auto logical_id : resize_out_tv->getLogicalDomain()) {
    auto resize = dynamic_cast<Resize*>(logical_id->definition());
    if (resize == nullptr) {
      continue;
    }

Redundant Code

The resize variable is defined twice in the loop, which is redundant and can be confusing. The outer resize should be used directly without redefining it inside the loop.

for (auto resize : resize_based_ops) {
  auto resize_out_tv = resize->output(0)->as<TensorView>();
  for (const auto logical_id : resize_out_tv->getLogicalDomain()) {
    auto resize = dynamic_cast<Resize*>(logical_id->definition());
    if (resize == nullptr) {
      continue;
    }

Test Naming

The test VectorizeInnerSliceMultiplePaths is named similarly to VectorizeSliceMultiplePaths, which might cause confusion. Consider renaming it to better reflect its purpose.

// one of the paths from tv6 to tv0 is considered.

@naoyam
Collaborator Author

naoyam commented Apr 24, 2025

!test --diff

@naoyam
Collaborator Author

naoyam commented Apr 24, 2025

!test --diff

@naoyam
Collaborator Author

naoyam commented Apr 25, 2025

!test --diff

EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2);
}

// The current analysis is not precise enough to pass this test
Collaborator Author

In this test, tv0 is resized in two different ways. The spanning-tree-based analysis is not guaranteed to correctly identify the vectorization constraint.

The WAR, when applied to this case, is still too conservative.

@naoyam naoyam marked this pull request as ready for review April 25, 2025 05:11
@naoyam naoyam added the rope label Apr 25, 2025
@naoyam
Collaborator Author

naoyam commented Apr 25, 2025

!build

@naoyam naoyam requested a review from jjsjann123 April 25, 2025 05:25

@jjsjann123 jjsjann123 left a comment
Collaborator

LGTM, glad to see that we are addressing the earlier comment


// Check if vectorization is properly applied even when a resized ID
// is reachable from vectorized IDs. Pattern extracted from Litgpt
// LLama RoPE backward.

Collaborator

nitpick on comment.

This is the case where it's safe to skip the additional resize check. So that means the resized ID is NOT reachable from vectorized IDs.

}
max_vec_size = std::gcd(max_vec_size, inferred_val.as<int64_t>());
auto inferred_val_int = inferred_val.as<int64_t>();
if (inferred_val_int == 0) {

Collaborator

this is for dynamic resize extents that would be 0?

resize_in_out_groups.pushBack(graph.toGroup(resize->out()));
CanSkipResize bfs(graph, ref_groups, resize_in_out_groups, resize);
bfs.traverse();
return bfs.allToNodesVisited();

Collaborator

qq: here we are calling allToNodesVisited()? But the init function below has /*require_all_to_visited=*/false, so we are returning true here as long as a single node is visited in the target, which I think is the right behavior.

But the function name is somewhat confusing.

Collaborator Author

No, the traversal should continue until no further progress is made. The require_all_to_visited flag means it's considered an error if not all of the to nodes were able to be reached.

Here, we just want to check whether all of the to nodes are reachable. It isn't an error if some are not.

Collaborator

> so we are returning true here as long as a single node is visited in the target, which I think is the right behavior.

Sorry, I got confused myself. This returns true indicating it's safe to skip the check, so allToNodesVisited is the proper name for the function.

Thanks for elaborating on this one.

@naoyam naoyam merged commit 07effe8 into main Apr 25, 2025
16 checks passed
@naoyam naoyam deleted the fix_resize_vec branch April 25, 2025 19:36
