Resize scheduler vectorization WAR#3955

Merged
naoyam merged 6 commits into main from do_not_cancel_innermost_reshape on Feb 25, 2025

Conversation

@naoyam
Collaborator

@naoyam naoyam commented Feb 24, 2025

This is a WAR for an issue with vectorization in the resize scheduler (unrelated to #3640).

#3693 introduced a reordering optimization for the resize scheduler that attempts to minimize the strides of read accesses of fusion inputs by canceling reshapes. It turns out this can conflict with vectorization: the scheduler uses the fusion input as the reference for the vectorization analysis, assuming every reshape is canceled, which is not always the case.

So, in this PR, the vectorization analysis is changed to use the reference output instead. However, that alone isn't enough, since when a reshape is indeed canceled, the analysis should actually be done using the pre-reshape shape.

To work around that, this PR also adds a flag to disable canceling reshapes that use innermost logical IDs. This ensures it is always valid to use the fusion output as the reference for the vectorization analysis.

This is an ad-hoc WAR but should be good enough for the RoPE cases. The underlying problems are somewhat intertwined here, and I'm not attempting to address them completely in this PR.

@naoyam
Collaborator Author

naoyam commented Feb 24, 2025

!test --diff

@github-actions

github-actions bot commented Feb 24, 2025

Review updated until commit 7583433

Description

  • Adjust vectorization reference to fusion output

  • Add flag to skip canceling reshapes with innermost IDs

  • Update tests to validate vectorization behavior

  • Improve comments and documentation for clarity


Changes walkthrough 📝

Relevant files

Enhancement

csrc/scheduler/resize.cpp: Update vectorization reference and reshape cancellation (+4/-8)

  • Changed vectorization reference from largest input to fusion output
  • Updated cancelReshapeInLoopDomains call to skip innermost IDs

csrc/scheduler/tools/loop_domain_scheduler.cpp: Add skip_innermost_id parameter to reshape cancellation (+25/-2)

  • Added skip_innermost_id parameter to cancelReshapeInLoopDomains
  • Updated logic to skip reshapes involving innermost logical IDs
  • Improved comments for clarity

csrc/scheduler/tools/loop_domain_scheduler.h: Update function signature for reshape cancellation (+7/-1)

  • Updated cancelReshapeInLoopDomains function signature to include skip_innermost_id
  • Improved comments to explain the new parameter

Tests

tests/cpp/test_resize.cpp: Update and add vectorization tests (+82/-2)

  • Updated test cases to use new shape values
  • Added new tests for vectorization with reshapes

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Vectorization Logic

The change in vectorization logic to use the reference output tensor may need further validation to ensure it doesn't introduce new issues or regressions in other scenarios.

```cpp
runtime_info,
ref_tv,
data_cache,
(int64_t)ref_tv->getLogicalDomain().size() - 1,
{});
```
Reshape Cancellation

The addition of the skip_innermost_id flag in cancelReshapeInLoopDomains should be carefully evaluated to ensure it doesn't inadvertently prevent valid reshape cancellations in other cases.

```cpp
Fusion* fusion = from_tv->fusion();
IdModel id_model(fusion, /*build_graphs=*/false);
id_model.buildExactGraph();
const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT);

// Reshapes producing these IDs should not be cancelled
ValGroups reshape_dependent_ids;
for (const ExprGroup& expr_g :
     exact_graph.disjointExprSets().disjointSets()) {
  if (expr_g->front()->isA<Resize>()) {
    reshape_dependent_ids.pushBack(exact_graph.inputGroups(expr_g));
  }
}

for (const ValGroup& val_g : exact_graph.disjointValSets().disjointSets()) {
  if (std::any_of(val_g->begin(), val_g->end(), [](Val* val) {
        NVF_ERROR(val->isA<IterDomain>());
        return val->as<IterDomain>()->isReduction();
      })) {
    reshape_dependent_ids.pushBack(val_g);
  }
}
```
    
    Test Coverage

    While new tests have been added, it's important to ensure that the test cases cover a wide range of scenarios to validate the effectiveness of the WAR and prevent future regressions.

```cpp
    }

    EXPECT_NE(has_resize, has_index_op);
  }
}

// Split-based reshape followed by a slice. The reshape is not
// cancelable. The vectorization factor based on the innermost logical
// ID of the input is not a valid factor as the fusion is scheduled
// based on the post-reshape shape.
TEST_F(ResizeTest, VectorizeInnermostWithReshapeSplit) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  std::vector<int64_t> shape1{128L * 16L};
  std::vector<int64_t> shape2{shape1[0] / 2L, 2L};

  auto tv0 = makeContigConcreteTensor(shape1);
  fusion.addInput(tv0);

  auto tv1 = sin(tv0);
  auto tv2 = reshape(tv1, shape1, shape2);
  auto tv3 = slice(
      tv2,
      {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)},
       {IrBuilder::create<Val>(0L), IrBuilder::create<Val>(shape2[1])}});
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn(shape1, options);

  auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {t0});
  testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__);

  // Should be vectorized by a factor of 2 because the resize scheduler
  // only uses the innermost logical ID, and the extent of the output
  // tensor is just 2. Before PR #3955, the resize scheduler
  // attempted to vectorize by 4. Note that the slice op itself does
  // not matter for the vectorization as the sliced ID is not involved
  // in the vectorization.
  EXPECT_EQ(
      tv3->getLoopDomain().back()->getParallelType(), ParallelType::Vectorize);
  EXPECT_EQ(tv3->getLoopDomain().back()->extent()->evaluate(), 2);
}

// Merge-based reshape followed by a slice. The reshape is
// cancelable. If the output is used as the reference but the reshape
// is canceled, the valid vectorization factor should be 2. The WAR of
// PR #3955 gives up canceling any reshape that involves innermost
// logical IDs to avoid this inconsistency.
TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  std::vector<int64_t> shape2{16, 128L * 16L};
  std::vector<int64_t> shape1{16, shape2[1] / 2L, 2L};

  auto tv0 = makeContigConcreteTensor(shape1);
  fusion.addInput(tv0);

  auto tv1 = sin(tv0);
  // [16, 128 * 16 / 2, 2] -> [16, 128 * 16]. Cancelable reshape.
  auto tv2 = reshape(tv1, shape1, shape2);
  auto tv3 = slice(
      tv2,
      {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)},
       {IrBuilder::create<Val>(0L), IrBuilder::create<Val>(shape2[1])}});
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn(shape1, options);

  auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {t0});
  testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__);

  // Should be vectorized by a factor of 4. If the reshape were canceled,
  // it should have been 2, but in this case since it involves the
  // innermost logical ID of tv2, it is not canceled, thus
  // vectorization by 4 should be chosen.
  EXPECT_EQ(
      tv3->getLoopDomain().back()->getParallelType(), ParallelType::Vectorize);
  EXPECT_EQ(tv3->getLoopDomain().back()->extent()->evaluate(), 4);
}

} // namespace nvfuser
```

@naoyam
Collaborator Author

    naoyam commented Feb 24, 2025

    !test --diff

    @naoyam naoyam marked this pull request as ready for review February 24, 2025 23:38
    @naoyam naoyam changed the title Resize scheduler vectorization fix Resize scheduler vectorization WAR Feb 24, 2025
    @naoyam naoyam requested a review from jjsjann123 February 24, 2025 23:39
```cpp
const std::vector<IterDomain*>& output_ids_;
};

// Replay a given IterDomain transform expression on the loop domain
```
naoyam (Collaborator Author):
    Unrelated to this PR, but this is a comment I had for a past PR but forgot to add.

```diff
 // When the direction is forward, the TensorView transform
 // APIs, e.g., TensorView::split, can be used, which doesn't need
-// to use TensorView::setLoopDomain.
+// to use TensorView::setLoopDomain. This is important as
```
naoyam (Collaborator Author):

    Unrelated to this PR, but this is a comment I had for a past PR but forgot to add.

@naoyam
Collaborator Author

    naoyam commented Feb 25, 2025

    !test --diff

```diff
 FusionGuard fg(fusion_ptr.get());

-std::vector<int64_t> shape({-1, 100});
+std::vector<int64_t> shape({-1, 128});
```
naoyam (Collaborator Author):

    This is a WAR for the other vectorization issue.

Collaborator:

    Is the other vectorization issue referring to #3640 ?

naoyam (Collaborator Author):

    Yes

@jjsjann123 jjsjann123 (Collaborator) left a comment

    Looks straightforward. Stamping.


@naoyam
Collaborator Author

    naoyam commented Feb 25, 2025

I just realized that disabling the cancellation does indeed matter for Mistral backward. I'll merge this for correctness and will think about the Mistral case later.

    @naoyam naoyam merged commit a7f13cb into main Feb 25, 2025
    58 of 60 checks passed
    @naoyam naoyam deleted the do_not_cancel_innermost_reshape branch February 25, 2025 18:05