-
Notifications
You must be signed in to change notification settings - Fork 79
Resize scheduler vectorization WAR #3955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,6 +99,8 @@ class LoopDomainSchedulerReplayTransform : OptInConstDispatch { | |
| const std::vector<IterDomain*>& output_ids_; | ||
| }; | ||
|
|
||
| // Replay a given IterDomain transform expression on the loop domain | ||
| // of a given tensor using specified loop IDs as its inputs. | ||
| class ReplayForwardTransformOnLoopDomain : OptInConstDispatch { | ||
| public: | ||
| static void replayAs( | ||
|
|
@@ -541,7 +543,11 @@ void scheduleLoopDomainsBy( | |
|
|
||
| // When the direction is forward, the TensorView transform | ||
| // APIs, e.g., TensorView::split, can be used, which doesn't need | ||
| // to use TensorView::setLoopDomain. | ||
| // to use TensorView::setLoopDomain. This is important as | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unrelated to this PR, but this is a comment I had for a past PR but forgot to add. |
||
| // setLoopDomain may result in losing extra IDs added by prior | ||
| // scheduleLoopDomain calls, which was indeed the case with the | ||
| // Llama 3 RoPE backward (see also | ||
| // https://github.com/NVIDIA/Fuser/issues/3571). | ||
| if (replay_dir_tv == Direction::Forward) { | ||
| ReplayForwardTransformOnLoopDomain::replayAs(tv, input_ids, transform); | ||
| continue; | ||
|
|
@@ -582,7 +588,7 @@ void scheduleLoopDomainsBy( | |
| return; | ||
| } | ||
|
|
||
| void cancelReshapeInLoopDomains(TensorView* from_tv) { | ||
| void cancelReshapeInLoopDomains(TensorView* from_tv, bool skip_innermost_id) { | ||
| Fusion* fusion = from_tv->fusion(); | ||
| IdModel id_model(fusion, /*build_graphs=*/false); | ||
| id_model.buildExactGraph(); | ||
|
|
@@ -677,13 +683,30 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) { | |
| {reshape_out->getLogicalDomain().begin(), | ||
| reshape_out->getLogicalDomain().end()}); | ||
|
|
||
| std::unordered_set<Expr*> reshape_exprs_with_innermost_logical_id_set; | ||
| if (skip_innermost_id) { | ||
| auto reshape_exprs_with_innermost_logical_id = | ||
| DependencyCheck::getAllExprsBetween( | ||
| {reshape_out->getRootDomain().begin(), | ||
| reshape_out->getRootDomain().end()}, | ||
| {reshape_out->getLogicalDomain().back()}); | ||
| reshape_exprs_with_innermost_logical_id_set = { | ||
| reshape_exprs_with_innermost_logical_id.begin(), | ||
| reshape_exprs_with_innermost_logical_id.end()}; | ||
| } | ||
|
|
||
| auto reshape_out_loop_domain = reshape_out->getLoopDomain(); | ||
|
|
||
| for (auto reshape_exprs_it = reshape_exprs.rbegin(); | ||
| reshape_exprs_it != reshape_exprs.rend(); | ||
| ++reshape_exprs_it) { | ||
| auto reshape_expr = *reshape_exprs_it; | ||
|
|
||
| if (skip_innermost_id && | ||
| reshape_exprs_with_innermost_logical_id_set.count(reshape_expr)) { | ||
| continue; | ||
| } | ||
|
|
||
| // If any of the output IDs of reshape_expr is not found in | ||
| // cancellable_ids, that means the expr cannot be cancelled. | ||
| if (std::any_of( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4755,7 +4755,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { | |
| Fusion& fusion = *fusion_ptr; | ||
| FusionGuard fg(fusion_ptr.get()); | ||
|
|
||
| std::vector<int64_t> shape({-1, 100}); | ||
| std::vector<int64_t> shape({-1, 128}); | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a WAR for the other vectorization issue.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes |
||
|
|
||
| EnableOptionsGuard enable_options_guard; | ||
| EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); | ||
|
|
@@ -4781,7 +4781,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { | |
| auto tv5 = cat({tv4, tv2}, 1); | ||
|
|
||
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); | ||
| auto t0 = at::randn({16, 100}, options); | ||
| auto t0 = at::randn({16, 128}, options); | ||
|
|
||
| fusion.addOutput(tv5); | ||
|
|
||
|
|
@@ -5822,4 +5822,84 @@ TEST_F(ResizeTest, DoNotFuseResizeAndIndexOps) { | |
| } | ||
| } | ||
|
|
||
| // Split-based reshape followed by a slice. The reshape is not | ||
| // cancelable. The vectorization factor based on the innermost logical | ||
| // ID of the input is not a valid factor as the fusion is scheduled | ||
| // based on the post-reshape shape. | ||
| TEST_F(ResizeTest, VectorizeInnermostWithReshapeSplit) { | ||
| auto fusion_ptr = std::make_unique<Fusion>(); | ||
| auto& fusion = *fusion_ptr; | ||
| FusionGuard fg(fusion_ptr.get()); | ||
|
|
||
| std::vector<int64_t> shape1{128L * 16L}; | ||
| std::vector<int64_t> shape2{shape1[0] / 2L, 2L}; | ||
|
|
||
| auto tv0 = makeContigConcreteTensor(shape1); | ||
| fusion.addInput(tv0); | ||
|
|
||
| auto tv1 = sin(tv0); | ||
| auto tv2 = reshape(tv1, shape1, shape2); | ||
| auto tv3 = slice( | ||
| tv2, | ||
| {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)}, | ||
| {IrBuilder::create<Val>(0L), IrBuilder::create<Val>(shape2[1])}}); | ||
| fusion.addOutput(tv3); | ||
|
|
||
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); | ||
| auto t0 = at::randn(shape1, options); | ||
|
|
||
| auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {t0}); | ||
| testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); | ||
|
|
||
| // Should be vectorized by a factor of 2 because the resize scheduler | ||
| // only uses the innermost logical ID, and the extent of the output | ||
| // tensor is just 2. Before PR #3955, the resize scheduler | ||
| // attempted to vectorize by 4. Note that the slice op itself does | ||
| // not matter for the vectorization as the sliced ID is not involved | ||
| // in the vectorization. | ||
| EXPECT_EQ( | ||
| tv3->getLoopDomain().back()->getParallelType(), ParallelType::Vectorize); | ||
| EXPECT_EQ(tv3->getLoopDomain().back()->extent()->evaluate(), 2); | ||
| } | ||
|
|
||
| // Merge-based reshape followed by a slice. The reshape is | ||
| // cancelable. If the output is used as the reference but the reshape | ||
| // is canceled, the valid vectorization factor should be 2. The WAR of | ||
| // PR #3955 gives up canceling any reshape that involves innermost | ||
| // logical IDs to avoid this inconsistency. | ||
| TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) { | ||
| auto fusion_ptr = std::make_unique<Fusion>(); | ||
| auto& fusion = *fusion_ptr; | ||
| FusionGuard fg(fusion_ptr.get()); | ||
|
|
||
| std::vector<int64_t> shape2{16, 128L * 16L}; | ||
| std::vector<int64_t> shape1{16, shape2[1] / 2L, 2L}; | ||
|
|
||
| auto tv0 = makeContigConcreteTensor(shape1); | ||
| fusion.addInput(tv0); | ||
|
|
||
| auto tv1 = sin(tv0); | ||
| // [16, 128 * 16 / 2, 2] -> [16, 128 * 16]. Cancelable reshape. | ||
| auto tv2 = reshape(tv1, shape1, shape2); | ||
| auto tv3 = slice( | ||
| tv2, | ||
| {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)}, | ||
| {IrBuilder::create<Val>(0L), IrBuilder::create<Val>(shape2[1])}}); | ||
| fusion.addOutput(tv3); | ||
|
|
||
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); | ||
| auto t0 = at::randn(shape1, options); | ||
|
|
||
| auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {t0}); | ||
| testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); | ||
|
|
||
| // Should be vectorized by a factor of 4. If the reshape were canceled, | ||
| // it should have been 2, but in this case since it involves the | ||
| // innermost logical ID of tv2, it is not canceled, thus | ||
| // vectorization by 4 should be chosen. | ||
| EXPECT_EQ( | ||
| tv3->getLoopDomain().back()->getParallelType(), ParallelType::Vectorize); | ||
| EXPECT_EQ(tv3->getLoopDomain().back()->extent()->evaluate(), 4); | ||
| } | ||
|
|
||
| } // namespace nvfuser | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to this PR, but this is a comment I had for a past PR but forgot to add.